[V1][P/D] Local attention optimization for NIXL #18170

Merged: 7 commits, May 17, 2025
101 changes: 90 additions & 11 deletions in vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -96,7 +96,8 @@ def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole):
self.connector_worker: Optional[NixlConnectorWorker] = None
elif role == KVConnectorRole.WORKER:
self.connector_scheduler = None
self.connector_worker = NixlConnectorWorker(str(self.engine_id))
self.connector_worker = NixlConnectorWorker(
vllm_config, str(self.engine_id))

############################################################
# Scheduler Side Methods
@@ -302,7 +303,7 @@ def request_finished(
class NixlConnectorWorker:
"""Implementation of Worker side methods"""

def __init__(self, engine_id: str):
def __init__(self, vllm_config: VllmConfig, engine_id: str):
if NixlWrapper is None:
logger.error("NIXL is not available")
raise RuntimeError("NIXL is not available")
@@ -329,6 +330,7 @@ def __init__(self, engine_id: str):
# Number of NIXL regions. Currently one region per cache
# (so 1 per layer for MLA, otherwise 2 per layer)
self.num_regions = 0
self.num_layers = 0

# nixl_prepped_dlist_handle (int).
self.src_xfer_side_handle: int = 0
@@ -355,6 +357,14 @@ def __init__(self, engine_id: str):
# Background thread for establishing new connections.
self._nixl_handshake_listener_t: Optional[threading.Thread] = None

self.vllm_config = vllm_config
self.block_size = vllm_config.cache_config.block_size

# TODO(mgoin): remove this once we have hybrid memory allocator
# Optimization for models with local attention (Llama 4)
# List of block window sizes for each layer for local attention
self.block_window_per_layer: list[Optional[int]] = []

@staticmethod
def _nixl_handshake_listener(metadata: NixlAgentMetadata,
ready_event: threading.Event, rank: int):
@@ -465,6 +475,27 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
kv_caches_base_addr.append(base_addr)
self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr
self.num_regions = len(caches_data)
self.num_layers = len(self.kv_caches.keys())

# TODO(mgoin): remove this once we have hybrid memory allocator
# Optimization for models with local attention (Llama 4)
if self.vllm_config.model_config.hf_config.model_type == "llama4":
from transformers import Llama4TextConfig
assert isinstance(self.vllm_config.model_config.hf_text_config,
Llama4TextConfig)
llama4_config = self.vllm_config.model_config.hf_text_config
no_rope_layers = llama4_config.no_rope_layers
chunk_size = llama4_config.attention_chunk_size
chunk_block_size = math.ceil(chunk_size / self.block_size)
for layer_idx in range(self.num_layers):
# no_rope_layers[layer_idx] == 0 means NoPE (global)
# Any other value means RoPE (local chunked)
is_local_attention = no_rope_layers[layer_idx] != 0
block_window = chunk_block_size if is_local_attention else None
self.block_window_per_layer.append(block_window)
logger.debug("Llama 4 block window per layer mapping: %s",
self.block_window_per_layer)
assert len(self.block_window_per_layer) == self.num_layers

descs = self.nixl_wrapper.get_reg_descs(caches_data, "VRAM")
logger.debug("Registering descs: %s", caches_data)
@@ -699,10 +730,39 @@ def _read_blocks(
remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id]

# Get descs ids.
remote_block_descs_ids = self._get_block_descs_ids(
dst_engine_id, remote_block_ids)
local_block_descs_ids = self._get_block_descs_ids(
self.engine_id, local_block_ids)
local_block_descs_ids: list[int] = []
remote_block_descs_ids: list[int] = []
if not self.block_window_per_layer:
# Default case: assume global attention
remote_block_descs_ids = self._get_block_descs_ids(
dst_engine_id, remote_block_ids)
local_block_descs_ids = self._get_block_descs_ids(
self.engine_id, local_block_ids)
else:
# TODO(mgoin): remove this once we have hybrid memory allocator
# Optimization for models with local attention (Llama 4)
for layer_idx, block_window in enumerate(
self.block_window_per_layer):
# For each layer:
if block_window is None:
# If not chunked, we just use the
# full block lists (global attention)
layer_local_block_ids = local_block_ids
layer_remote_block_ids = remote_block_ids
else:
# If chunked, get the last block_window blocks
layer_local_block_ids = local_block_ids[-block_window:]
layer_remote_block_ids = remote_block_ids[-block_window:]

# Get descs ids for the layer.
layer_local_desc_ids = self._get_block_descs_ids(
self.engine_id, layer_local_block_ids, layer_idx)
layer_remote_desc_ids = self._get_block_descs_ids(
dst_engine_id, layer_remote_block_ids, layer_idx)

local_block_descs_ids.extend(layer_local_desc_ids)
remote_block_descs_ids.extend(layer_remote_desc_ids)

assert len(local_block_descs_ids) == len(remote_block_descs_ids)
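The per-layer selection rule is small enough to sketch on its own. In this illustrative snippet (the helper name `select_layer_blocks` is invented), global layers keep the full block list while chunked layers keep only the trailing `block_window` blocks; note that Python's negative slicing already does the right thing when a request has fewer blocks than the window.

```python
from typing import Optional

def select_layer_blocks(block_ids: list[int],
                        block_window: Optional[int]) -> list[int]:
    """Mirror of the per-layer branch above (illustrative helper)."""
    if block_window is None:
        # Global attention: the layer needs every block.
        return block_ids
    # Local chunked attention: only the last `block_window` blocks can
    # be attended to. Slicing a list shorter than the window simply
    # returns the whole list, so short requests are handled for free.
    return block_ids[-block_window:]

blocks = [10, 11, 12, 13, 14, 15]
for layer_idx, window in enumerate([None, 4, 4]):
    print(layer_idx, select_layer_blocks(blocks, window))
# 0 [10, 11, 12, 13, 14, 15]   <- global layer reads all six blocks
# 1 [12, 13, 14, 15]           <- local layers read only the last four
# 2 [12, 13, 14, 15]
```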

# Prepare transfer with Nixl.
@@ -721,12 +781,31 @@ def _read_blocks(
# Use handle to check completion in future step().
self._recving_transfers[request_id].append(handle)

def _get_block_descs_ids(self, engine_id: str,
block_ids: list[int]) -> list[int]:
"""Get the descs ids for a set of block ids."""
def _get_block_descs_ids(self,
engine_id: str,
block_ids: list[int],
layer_idx: Optional[int] = None) -> list[int]:
"""
Get the descs ids for a set of block ids.
If layer_idx is provided, we use the region_ids for the given layer.
Otherwise, we use all regions.
"""

if layer_idx is None:
region_ids = range(self.num_regions)
else:
assert layer_idx < self.num_layers
if self.num_layers < self.num_regions:
# If we have more regions than layers, we assume that
# the regions are organized as [K0, V0, K1, V1, ...]
# and we select K_i and V_i
assert 2 * self.num_layers == self.num_regions
region_ids = range(2 * layer_idx, 2 * layer_idx + 2)
else:
# Otherwise, we assume we have MLA and select i-th layer
assert self.num_layers == self.num_regions
region_ids = range(layer_idx, layer_idx + 1)

# range(1) for MLA, range(2) otherwise.
region_ids = range(self.num_regions)
num_blocks = self.dst_num_blocks[engine_id]

# Compute the desc ids for each block.
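The region arithmetic can also be exercised standalone. Below is a sketch under the same assumptions as the code above, with an invented helper name (`select_region_ids`): non-MLA models register separate K and V regions interleaved as [K0, V0, K1, V1, ...], so layer i owns regions 2i and 2i+1, while MLA registers one combined region per layer.

```python
from typing import Optional

def select_region_ids(num_regions: int, num_layers: int,
                      layer_idx: Optional[int] = None) -> range:
    """Illustrative copy of the region selection in _get_block_descs_ids."""
    if layer_idx is None:
        # No layer given: span every region (the pre-existing behavior).
        return range(num_regions)
    assert layer_idx < num_layers
    if num_layers < num_regions:
        # Separate K/V caches: pick the K_i, V_i pair for this layer.
        assert 2 * num_layers == num_regions
        return range(2 * layer_idx, 2 * layer_idx + 2)
    # MLA: one combined cache per layer, so pick just region i.
    assert num_layers == num_regions
    return range(layer_idx, layer_idx + 1)

# 48 layers with separate K/V caches -> 96 regions; layer 3 owns K3 and V3:
print(list(select_region_ids(96, 48, layer_idx=3)))  # [6, 7]
# MLA model: one region per layer; layer 3 owns region 3:
print(list(select_region_ids(48, 48, layer_idx=3)))  # [3]
```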