Commit fd195b1
[V1][P/D] Local attention optimization for NIXL (#18170)
Signed-off-by: mgoin <mgoin64@gmail.com>
1 parent fabe89b commit fd195b1

1 file changed (+90 −11): vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
```diff
@@ -96,7 +96,8 @@ def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole):
             self.connector_worker: Optional[NixlConnectorWorker] = None
         elif role == KVConnectorRole.WORKER:
             self.connector_scheduler = None
-            self.connector_worker = NixlConnectorWorker(str(self.engine_id))
+            self.connector_worker = NixlConnectorWorker(
+                vllm_config, str(self.engine_id))
 
     ############################################################
     # Scheduler Side Methods
```
```diff
@@ -302,7 +303,7 @@ def request_finished(
 class NixlConnectorWorker:
     """Implementation of Worker side methods"""
 
-    def __init__(self, engine_id: str):
+    def __init__(self, vllm_config: VllmConfig, engine_id: str):
         if NixlWrapper is None:
             logger.error("NIXL is not available")
             raise RuntimeError("NIXL is not available")
```
```diff
@@ -329,6 +330,7 @@ def __init__(self, engine_id: str):
         # Number of NIXL regions. Currently one region per cache
         # (so 1 per layer for MLA, otherwise 2 per layer)
         self.num_regions = 0
+        self.num_layers = 0
 
         # nixl_prepped_dlist_handle (int).
         self.src_xfer_side_handle: int = 0
```
```diff
@@ -355,6 +357,14 @@ def __init__(self, engine_id: str):
         # Background thread for establishing new connections.
         self._nixl_handshake_listener_t: Optional[threading.Thread] = None
 
+        self.vllm_config = vllm_config
+        self.block_size = vllm_config.cache_config.block_size
+
+        # TODO(mgoin): remove this once we have hybrid memory allocator
+        # Optimization for models with local attention (Llama 4)
+        # List of block window sizes for each layer for local attention
+        self.block_window_per_layer: list[Optional[int]] = []
+
     @staticmethod
     def _nixl_handshake_listener(metadata: NixlAgentMetadata,
                                  ready_event: threading.Event, rank: int):
```
```diff
@@ -465,6 +475,27 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
             kv_caches_base_addr.append(base_addr)
         self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr
         self.num_regions = len(caches_data)
+        self.num_layers = len(self.kv_caches.keys())
+
+        # TODO(mgoin): remove this once we have hybrid memory allocator
+        # Optimization for models with local attention (Llama 4)
+        if self.vllm_config.model_config.hf_config.model_type == "llama4":
+            from transformers import Llama4TextConfig
+            assert isinstance(self.vllm_config.model_config.hf_text_config,
+                              Llama4TextConfig)
+            llama4_config = self.vllm_config.model_config.hf_text_config
+            no_rope_layers = llama4_config.no_rope_layers
+            chunk_size = llama4_config.attention_chunk_size
+            chunk_block_size = math.ceil(chunk_size / self.block_size)
+            for layer_idx in range(self.num_layers):
+                # no_rope_layers[layer_idx] == 0 means NoPE (global)
+                # Any other value means RoPE (local chunked)
+                is_local_attention = no_rope_layers[layer_idx] != 0
+                block_window = chunk_block_size if is_local_attention else None
+                self.block_window_per_layer.append(block_window)
+            logger.debug("Llama 4 block window per layer mapping: %s",
+                         self.block_window_per_layer)
+            assert len(self.block_window_per_layer) == self.num_layers
 
         descs = self.nixl_wrapper.get_reg_descs(caches_data, "VRAM")
         logger.debug("Registering descs: %s", caches_data)
```
```diff
@@ -699,10 +730,39 @@ def _read_blocks(
         remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id]
 
         # Get descs ids.
-        remote_block_descs_ids = self._get_block_descs_ids(
-            dst_engine_id, remote_block_ids)
-        local_block_descs_ids = self._get_block_descs_ids(
-            self.engine_id, local_block_ids)
+        local_block_descs_ids: list[int] = []
+        remote_block_descs_ids: list[int] = []
+        if not self.block_window_per_layer:
+            # Default case: assume global attention
+            remote_block_descs_ids = self._get_block_descs_ids(
+                dst_engine_id, remote_block_ids)
+            local_block_descs_ids = self._get_block_descs_ids(
+                self.engine_id, local_block_ids)
+        else:
+            # TODO(mgoin): remove this once we have hybrid memory allocator
+            # Optimization for models with local attention (Llama 4)
+            for layer_idx, block_window in enumerate(
+                    self.block_window_per_layer):
+                # For each layer:
+                if block_window is None:
+                    # If not chunked, we just use the
+                    # full block lists (global attention)
+                    layer_local_block_ids = local_block_ids
+                    layer_remote_block_ids = remote_block_ids
+                else:
+                    # If chunked, get the last block_window blocks
+                    layer_local_block_ids = local_block_ids[-block_window:]
+                    layer_remote_block_ids = remote_block_ids[-block_window:]
+
+                # Get descs ids for the layer.
+                layer_local_desc_ids = self._get_block_descs_ids(
+                    self.engine_id, layer_local_block_ids, layer_idx)
+                layer_remote_desc_ids = self._get_block_descs_ids(
+                    dst_engine_id, layer_remote_block_ids, layer_idx)
+
+                local_block_descs_ids.extend(layer_local_desc_ids)
+                remote_block_descs_ids.extend(layer_remote_desc_ids)
+
         assert len(local_block_descs_ids) == len(remote_block_descs_ids)
 
         # Prepare transfer with Nixl.
```
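The savings from the per-layer windowing above can be sized with a quick sketch. `blocks_to_read` is a hypothetical helper (not part of the connector), and the request size and window values are made up to match the earlier example:

```python
from typing import Optional

def blocks_to_read(block_ids: list[int],
                   block_window_per_layer: list[Optional[int]]
                   ) -> dict[int, list[int]]:
    """Per-layer block selection mirroring the hunk above: global layers
    read the full block list, local layers only the trailing window."""
    per_layer: dict[int, list[int]] = {}
    for layer_idx, window in enumerate(block_window_per_layer):
        per_layer[layer_idx] = (block_ids if window is None
                                else block_ids[-window:])
    return per_layer

# Hypothetical request spanning 1000 blocks, 3 local layers + 1 global.
block_ids = list(range(1000))
windows = [512, 512, 512, None]
plan = blocks_to_read(block_ids, windows)
print(sum(len(v) for v in plan.values()))
# 2536 blocks read, versus 4000 if every layer were treated as global
```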
```diff
@@ -721,12 +781,31 @@ def _read_blocks(
         # Use handle to check completion in future step().
         self._recving_transfers[request_id].append(handle)
 
-    def _get_block_descs_ids(self, engine_id: str,
-                             block_ids: list[int]) -> list[int]:
-        """Get the descs ids for a set of block ids."""
+    def _get_block_descs_ids(self,
+                             engine_id: str,
+                             block_ids: list[int],
+                             layer_idx: Optional[int] = None) -> list[int]:
+        """
+        Get the descs ids for a set of block ids.
+        If layer_idx is provided, we use the region_ids for the given layer.
+        Otherwise, we use all regions.
+        """
+
+        if layer_idx is None:
+            region_ids = range(self.num_regions)
+        else:
+            assert layer_idx < self.num_layers
+            if self.num_layers < self.num_regions:
+                # If we have more regions than layers, we assume that
+                # the regions are organized as [K0, V0, K1, V1, ...]
+                # and we select K_i and V_i
+                assert 2 * self.num_layers == self.num_regions
+                region_ids = range(2 * layer_idx, 2 * layer_idx + 2)
+            else:
+                # Otherwise, we assume we have MLA and select i-th layer
+                assert self.num_layers == self.num_regions
+                region_ids = range(layer_idx, layer_idx + 1)
 
-        # range(1) for MLA, range(2) otherwise.
-        region_ids = range(self.num_regions)
         num_blocks = self.dst_num_blocks[engine_id]
 
         # Compute the desc ids for each block.
```
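The hunk ends before the final descriptor computation is shown. Based on the bookkeeping above (each region holds `num_blocks` descriptors, one region per cache), a minimal sketch of the likely indexing follows; the flat `reg_id * num_blocks + block_id` layout and the helper name are assumptions for illustration, not lines from this commit:

```python
def get_block_descs_ids(region_ids: range,
                        block_ids: list[int],
                        num_blocks: int) -> list[int]:
    # One descriptor per (region, block) pair; regions are assumed to be
    # laid out back-to-back, num_blocks descriptors each.
    descs_ids: list[int] = []
    for reg_id in region_ids:
        for block_id in block_ids:
            descs_ids.append(reg_id * num_blocks + block_id)
    return descs_ids

# Layer 1 of a non-MLA model (regions [K0, V0, K1, V1, ...]):
# K1 is region 2 and V1 is region 3.
print(get_block_descs_ids(range(2, 4), [7, 8], num_blocks=10))
# [27, 28, 37, 38]
```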
