[P/D][NixlConnector] Enable FlashInfer backend #19090
Merged
Changes from all commits (2 commits)
@@ -15,13 +15,15 @@
 import zmq

 from vllm import envs
+from vllm.attention.selector import backend_name_to_enum, get_attn_backend
 from vllm.config import VllmConfig
 from vllm.distributed.kv_transfer.kv_connector.v1.base import (
     KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
 from vllm.distributed.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
     get_tp_group)
 from vllm.logger import init_logger
+from vllm.platforms import _Backend
 from vllm.utils import make_zmq_path, make_zmq_socket, round_down
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.request import RequestStatus
@@ -57,6 +59,7 @@ class NixlAgentMetadata(
     num_blocks: int
     tp_size: int
     block_len: int
+    attn_backend_name: str


 @dataclass
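For orientation, a rough sketch of what the handshake payload carries after this change. This is a plain-dataclass stand-in rather than the real NixlAgentMetadata definition; the engine_id, agent_metadata and kv_caches_base_addr fields are inferred from the constructor call further down the diff, and the actual serialization mechanism is omitted.

# Illustrative stand-in for the handshake metadata (not the actual class).
from dataclasses import dataclass

@dataclass
class HandshakeMetadataSketch:
    engine_id: str                   # identifies the remote engine (assumed field)
    agent_metadata: bytes            # opaque NIXL agent blob (assumed field)
    kv_caches_base_addr: list[int]   # base address per registered KV region
    num_blocks: int                  # blocks in the sender's KV cache
    tp_size: int                     # sender's tensor-parallel size
    block_len: int                   # bytes per KV block in each region
    attn_backend_name: str           # new: advertised attention backend name

# The receiving side can then refuse a peer whose backend (and hence KV cache
# layout) differs, which is what add_remote_agent asserts later in this diff.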
@@ -384,11 +387,25 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):

         self.vllm_config = vllm_config
         self.block_size = vllm_config.cache_config.block_size
+        self.model_config = vllm_config.model_config
+        self.cache_config = vllm_config.cache_config

         # TODO(mgoin): remove this once we have hybrid memory allocator
         # Optimization for models with local attention (Llama 4)
         # List of block window sizes for each layer for local attention
         self.block_window_per_layer: list[Optional[int]] = []
+        self.use_mla = self.model_config.use_mla
+
+        backend = get_attn_backend(self.model_config.get_head_size(),
+                                   self.model_config.dtype,
+                                   self.cache_config.cache_dtype,
+                                   self.block_size,
+                                   self.model_config.is_attention_free,
+                                   use_mla=self.use_mla)
+        self.backend_name = backend.get_name()
+        attn_backend = backend_name_to_enum(self.backend_name)
+        self._use_flashinfer = attn_backend == _Backend.FLASHINFER_VLLM_V1
+        logger.debug("Detected attention backend %s", self.backend_name)

         self._tp_size: dict[str, int] = {self.engine_id: self.world_size}
         # With heterogeneous TP, P must wait for all assigned D TP workers to
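The detected backend matters because it determines the per-layer KV cache layout that register_kv_caches (next hunks) has to interpret. A minimal sketch of the two non-MLA shapes, using toy sizes; the dimension order comes from the comments in the diff:

# Toy illustration of the two per-layer KV cache shapes handled below.
import torch

num_blocks, block_size, n_kv_heads, head_dim = 4, 16, 8, 128

# Default layout per the diff comments: [2 (k and v), num_blocks, block_size,
# kv_heads, head_dim]; K and V of a layer are effectively two separate regions.
kv_default = torch.empty(2, num_blocks, block_size, n_kv_heads, head_dim)

# FlashInfer layout: the 2 and num_blocks dims are swapped, so every block
# already carries its K and V halves together in one tensor.
kv_flashinfer = torch.empty(num_blocks, 2, block_size, n_kv_heads, head_dim)

print(kv_default.shape[1], kv_flashinfer.shape[0])  # both report num_blocks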
@@ -472,37 +489,44 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         kv_elem_size = first_kv_cache.element_size()

-        # TODO(tms): Find a more robust way to detect and handle MLA
-        self.use_mla = len(first_kv_cache.shape) == 3
         # NOTE (NickLucche) To move blocks efficiently with NIXL, the expected
         # KV memory layout is HND, as opposed to the default NHD. Note that it
         # will only affects the strides. For MLA instead, we make require no
         # such thing and resort to the standard layout.
-        if self.use_mla:
+        use_mla = len(first_kv_cache.shape) == 3
+        assert use_mla == self.use_mla
+
+        # TODO (NickLucche) not compatible with hybrid allocator. Enforce check
+        # once it goes live, as a single kv layout is expected for xfers.
+        if use_mla:
             # MLA case.
             self.num_blocks = first_kv_cache.shape[0]
             block_rank = 2  # [block_size, latent_dim]
             block_shape = first_kv_cache.shape[-block_rank:]
             block_size, kv_latent_dim = block_shape
             self.slot_size_bytes = kv_elem_size * kv_latent_dim
         else:
-            # [2 (k and v), num_blocks, block_size, kv_heads, head_dim]
-            self.num_blocks = first_kv_cache.shape[1]
-            block_rank = 3  # [block_size, kv_heads, head_dim]
+            # [2 (k and v), num_blocks, ...]
+            if self._use_flashinfer:
+                # FlashInfer swaps 2<->num_blocks dimensions.
+                self.num_blocks = first_kv_cache.shape[0]
+                block_rank = 4  # [2, block_size, kv_heads, head_dim]
+            else:
+                self.num_blocks = first_kv_cache.shape[1]
+                block_rank = 3  # [block_size, kv_heads, head_dim]
             block_shape = first_kv_cache.shape[-block_rank:]
-            block_size, n_kv_heads, head_dim = block_shape
+            block_size, n_kv_heads, head_dim = block_shape[-3:]
             # head size in bytes.
             self.slot_size_bytes = kv_elem_size * n_kv_heads * head_dim
         assert block_size == self.block_size
         # TODO(tms): self.block_len needs to be per-layer for sliding window,
         # hybrid attn, etc
         # block size in bytes
         self.block_len = kv_elem_size * math.prod(block_shape)

-        logger.debug("Registering KV_Caches. use_mla: %s, shape %s",
-                     self.use_mla, first_kv_cache.shape)
-        logger.debug("num_blocks: %s, block_shape: %s", self.num_blocks,
-                     block_shape)
-        logger.debug("Per layer kv cache size: %s", first_kv_cache.shape)
+        logger.info(
+            "Registering KV_Caches: use_mla: %s, num_blocks: %s, "
+            "block_shape: %s, per_layer_kv_cache_shape: %s", use_mla,
+            self.num_blocks, block_shape, first_kv_cache.shape)
         self.dst_num_blocks[self.engine_id] = self.num_blocks
         self.kv_caches = kv_caches
         kv_caches_base_addr = []
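To make the size bookkeeping above concrete, a small worked example with made-up values (fp16 cache, 16-token blocks, 8 KV heads, head_dim 128). The formulas are the ones from the hunk; the numbers are purely illustrative:

# Worked example of the size bookkeeping above (values are made up).
import math

kv_elem_size = 2                          # bytes per element for an fp16 cache
block_size, n_kv_heads, head_dim = 16, 8, 128

# Non-FlashInfer: block_shape = [block_size, kv_heads, head_dim].
block_shape = (block_size, n_kv_heads, head_dim)
slot_size_bytes = kv_elem_size * n_kv_heads * head_dim    # 2 * 8 * 128 = 2048
block_len = kv_elem_size * math.prod(block_shape)         # 2048 * 16 = 32768

# FlashInfer: block_shape = [2, block_size, kv_heads, head_dim], so a "block"
# covers both K and V and block_len doubles, while slot_size_bytes is unchanged.
fi_block_shape = (2, block_size, n_kv_heads, head_dim)
fi_block_len = kv_elem_size * math.prod(fi_block_shape)   # 65536

assert fi_block_len == 2 * block_len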
@@ -514,9 +538,12 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         # are non-contiguous (it's not locally guaranteed that they will be)
         # Disadvantage is that the encoded NixlAgentMetadata is now larger
         # (roughly 8KB vs 5KB).
+        # Conversely for FlashInfer, K and V are transferred in the same tensor
+        # to better exploit the memory layout (ie num_blocks is the first dim).
         for cache_or_caches in kv_caches.values():
             # Normalize to always be a list of caches
-            cache_list = [cache_or_caches] if self.use_mla else cache_or_caches
+            cache_list = [cache_or_caches] if use_mla or self._use_flashinfer \
+                else cache_or_caches
             for cache in cache_list:
                 base_addr = cache.data_ptr()
                 region_len = self.num_blocks * self.block_len
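A hedged sketch of what the normalization above amounts to: how many NIXL regions get registered per layer under each backend. The helper function and toy tensors are illustrative; only the one-region-versus-two-regions distinction comes from the diff.

# Sketch: regions registered per layer, depending on the backend in use.
import torch

def layer_regions(cache_or_caches, use_mla: bool, use_flashinfer: bool,
                  num_blocks: int, block_len: int):
    # MLA has a single latent cache per layer; FlashInfer keeps K and V in one
    # tensor; otherwise the layer's cache splits into separate K and V tensors.
    cache_list = ([cache_or_caches]
                  if use_mla or use_flashinfer else cache_or_caches)
    return [(cache.data_ptr(), num_blocks * block_len) for cache in cache_list]

kv = torch.empty(2, 4, 16, 8, 128, dtype=torch.float16)   # default layout
fi = torch.empty(4, 2, 16, 8, 128, dtype=torch.float16)   # FlashInfer layout
assert len(layer_regions(kv.unbind(0), False, False, 4, 32768)) == 2
assert len(layer_regions(fi, False, True, 4, 65536)) == 1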
@@ -581,7 +608,8 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
             kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
             num_blocks=self.num_blocks,
             tp_size=self.world_size,
-            block_len=self.block_len)
+            block_len=self.block_len,
+            attn_backend_name=self.backend_name)
         ready_event = threading.Event()
         self._nixl_handshake_listener_t = threading.Thread(
             target=self._nixl_handshake_listener,
@@ -641,6 +669,10 @@ def add_remote_agent(self,
             assert self._tp_size[engine_id] == nixl_agent_meta.tp_size
         else:
             self._tp_size[engine_id] = nixl_agent_meta.tp_size
+        # We may eventually enable this after asserting equality in cache
+        # layout and close outputs.
+        assert nixl_agent_meta.attn_backend_name == self.backend_name
+
         self._remote_agents[engine_id][
             remote_tp_rank] = self.nixl_wrapper.add_remote_agent(
                 nixl_agent_meta.agent_metadata)
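Illustrative only: how a prefill/decode backend mismatch would surface given the assert above. The backend-name strings are hypothetical placeholders, not pinned to vLLM's enum values.

# Illustrative only: a handshake-time backend mismatch tripping the check above.
local_backend_name = "FLASHINFER_VLLM_V1"     # hypothetical decode-side backend
remote_backend_name = "FLASH_ATTN_VLLM_V1"    # hypothetical prefill-side backend
try:
    assert remote_backend_name == local_backend_name
except AssertionError:
    # Mismatched backends imply mismatched KV layouts, so transfers are refused.
    print("P/D attention backends differ; refusing to add remote agent")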
@@ -659,13 +691,16 @@ def add_remote_agent(self,
         else:
             remote_block_size = nixl_agent_meta.block_len // (
                 self.slot_size_bytes * tp_ratio)
+            if self._use_flashinfer:
+                # Account for joint KV in FlashInfer.
+                remote_block_size //= 2

         assert nixl_agent_meta.block_len == self.block_len * tp_ratio, (
             "Remote P worker KV layer cache must be of shape [2, N, "
             "local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
         )

-        assert self.block_size == remote_block_size, "Remote P worker with "
+        assert self.block_size == remote_block_size, "Remote P worker with " \
             "different block size is not supported"

         assert self.num_blocks >= nixl_agent_meta.num_blocks

Review comment on the added trailing backslash: otherwise the assert message prints a tab in the console.
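A worked pass through the arithmetic above, reusing the made-up sizes from the earlier sketch; tp_ratio here is a hypothetical prefill-to-decode TP ratio, and the remote block_len values simply mirror the asserted relation block_len_remote == block_len_local * tp_ratio.

# Worked example of the remote block-size recovery (illustrative numbers only).
slot_size_bytes = 2 * 8 * 128                 # fp16 * 8 KV heads * head_dim 128
tp_ratio = 2                                  # hypothetical P-TP / D-TP ratio
block_size = 16

# Non-FlashInfer remote: its block_len covers K or V separately.
remote_block_len = block_size * slot_size_bytes * tp_ratio             # 65536
remote_block_size = remote_block_len // (slot_size_bytes * tp_ratio)   # 16

# FlashInfer remote: its block_len covers K and V jointly, hence the extra //= 2.
fi_remote_block_len = 2 * block_size * slot_size_bytes * tp_ratio      # 131072
fi_remote_block_size = fi_remote_block_len // (slot_size_bytes * tp_ratio) // 2

assert remote_block_size == fi_remote_block_size == block_size         # both 16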
Review comment on the assert use_mla == self.use_mla line: we could just replace self.use_mla at this point, just wanted to make sure the two match for review.