@@ -96,7 +96,8 @@ def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole):
             self.connector_worker: Optional[NixlConnectorWorker] = None
         elif role == KVConnectorRole.WORKER:
             self.connector_scheduler = None
-            self.connector_worker = NixlConnectorWorker(str(self.engine_id))
+            self.connector_worker = NixlConnectorWorker(
+                vllm_config, str(self.engine_id))
 
     ############################################################
     # Scheduler Side Methods
@@ -302,7 +303,7 @@ def request_finished(
 class NixlConnectorWorker:
     """Implementation of Worker side methods"""
 
-    def __init__(self, engine_id: str):
+    def __init__(self, vllm_config: VllmConfig, engine_id: str):
         if NixlWrapper is None:
             logger.error("NIXL is not available")
             raise RuntimeError("NIXL is not available")
@@ -329,6 +330,7 @@ def __init__(self, engine_id: str):
         # Number of NIXL regions. Currently one region per cache
         # (so 1 per layer for MLA, otherwise 2 per layer)
         self.num_regions = 0
+        self.num_layers = 0
 
         # nixl_prepped_dlist_handle (int).
         self.src_xfer_side_handle: int = 0
@@ -355,6 +357,14 @@ def __init__(self, engine_id: str):
         # Background thread for establishing new connections.
         self._nixl_handshake_listener_t: Optional[threading.Thread] = None
 
+        self.vllm_config = vllm_config
+        self.block_size = vllm_config.cache_config.block_size
+
+        # TODO(mgoin): remove this once we have hybrid memory allocator
+        # Optimization for models with local attention (Llama 4)
+        # List of block window sizes for each layer for local attention
+        self.block_window_per_layer: list[Optional[int]] = []
+
     @staticmethod
     def _nixl_handshake_listener(metadata: NixlAgentMetadata,
                                  ready_event: threading.Event, rank: int):
@@ -465,6 +475,27 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
             kv_caches_base_addr.append(base_addr)
         self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr
         self.num_regions = len(caches_data)
+        self.num_layers = len(self.kv_caches.keys())
+
+        # TODO(mgoin): remove this once we have hybrid memory allocator
+        # Optimization for models with local attention (Llama 4)
+        if self.vllm_config.model_config.hf_config.model_type == "llama4":
+            from transformers import Llama4TextConfig
+            assert isinstance(self.vllm_config.model_config.hf_text_config,
+                              Llama4TextConfig)
+            llama4_config = self.vllm_config.model_config.hf_text_config
+            no_rope_layers = llama4_config.no_rope_layers
+            chunk_size = llama4_config.attention_chunk_size
+            chunk_block_size = math.ceil(chunk_size / self.block_size)
+            for layer_idx in range(self.num_layers):
+                # no_rope_layers[layer_idx] == 0 means NoPE (global)
+                # Any other value means RoPE (local chunked)
+                is_local_attention = no_rope_layers[layer_idx] != 0
+                block_window = chunk_block_size if is_local_attention else None
+                self.block_window_per_layer.append(block_window)
+            logger.debug("Llama 4 block window per layer mapping: %s",
+                         self.block_window_per_layer)
+            assert len(self.block_window_per_layer) == self.num_layers
 
         descs = self.nixl_wrapper.get_reg_descs(caches_data, "VRAM")
         logger.debug("Registering descs: %s", caches_data)
@@ -699,10 +730,39 @@ def _read_blocks(
         remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id]
 
         # Get descs ids.
-        remote_block_descs_ids = self._get_block_descs_ids(
-            dst_engine_id, remote_block_ids)
-        local_block_descs_ids = self._get_block_descs_ids(
-            self.engine_id, local_block_ids)
+        local_block_descs_ids: list[int] = []
+        remote_block_descs_ids: list[int] = []
+        if not self.block_window_per_layer:
+            # Default case: assume global attention
+            remote_block_descs_ids = self._get_block_descs_ids(
+                dst_engine_id, remote_block_ids)
+            local_block_descs_ids = self._get_block_descs_ids(
+                self.engine_id, local_block_ids)
+        else:
+            # TODO(mgoin): remove this once we have hybrid memory allocator
+            # Optimization for models with local attention (Llama 4)
+            for layer_idx, block_window in enumerate(
+                    self.block_window_per_layer):
+                # For each layer:
+                if block_window is None:
+                    # If not chunked, we just use the
+                    # full block lists (global attention)
+                    layer_local_block_ids = local_block_ids
+                    layer_remote_block_ids = remote_block_ids
+                else:
+                    # If chunked, get the last block_window blocks
+                    layer_local_block_ids = local_block_ids[-block_window:]
+                    layer_remote_block_ids = remote_block_ids[-block_window:]
+
+                # Get descs ids for the layer.
+                layer_local_desc_ids = self._get_block_descs_ids(
+                    self.engine_id, layer_local_block_ids, layer_idx)
+                layer_remote_desc_ids = self._get_block_descs_ids(
+                    dst_engine_id, layer_remote_block_ids, layer_idx)
+
+                local_block_descs_ids.extend(layer_local_desc_ids)
+                remote_block_descs_ids.extend(layer_remote_desc_ids)
+
         assert len(local_block_descs_ids) == len(remote_block_descs_ids)
 
         # Prepare transfer with Nixl.
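The payoff of the per-layer path above: a chunked local-attention layer only ever attends within the trailing block_window blocks, so the rest of that layer's KV cache never needs to cross the wire. A rough back-of-the-envelope sketch, reusing the hypothetical numbers from the earlier example:

# Hypothetical: a 100,000-token prefill with 16-token blocks
# needs ceil(100_000 / 16) = 6250 KV blocks per layer.
num_prompt_blocks = 6250
block_window = 512  # ceil(8192 / 16) from the earlier sketch

# Global-attention (NoPE) layer: every block must be transferred.
global_blocks = num_prompt_blocks
# Chunked local-attention (RoPE) layer: only the trailing window.
local_blocks = min(block_window, num_prompt_blocks)

print(f"per-layer transfer savings: {1 - local_blocks / global_blocks:.1%}")
# -> per-layer transfer savings: 91.8%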
@@ -721,12 +781,31 @@ def _read_blocks(
         # Use handle to check completion in future step().
         self._recving_transfers[request_id].append(handle)
 
-    def _get_block_descs_ids(self, engine_id: str,
-                             block_ids: list[int]) -> list[int]:
-        """Get the descs ids for a set of block ids."""
+    def _get_block_descs_ids(self,
+                             engine_id: str,
+                             block_ids: list[int],
+                             layer_idx: Optional[int] = None) -> list[int]:
+        """
+        Get the descs ids for a set of block ids.
+        If layer_idx is provided, we use the region_ids for the given layer.
+        Otherwise, we use all regions.
+        """
+
+        if layer_idx is None:
+            region_ids = range(self.num_regions)
+        else:
+            assert layer_idx < self.num_layers
+            if self.num_layers < self.num_regions:
+                # If we have more regions than layers, we assume that
+                # the regions are organized as [K0, V0, K1, V1, ...]
+                # and we select K_i and V_i
+                assert 2 * self.num_layers == self.num_regions
+                region_ids = range(2 * layer_idx, 2 * layer_idx + 2)
+            else:
+                # Otherwise, we assume we have MLA and select i-th layer
+                assert self.num_layers == self.num_regions
+                region_ids = range(layer_idx, layer_idx + 1)
 
-        # range(1) for MLA, range(2) otherwise.
-        region_ids = range(self.num_regions)
         num_blocks = self.dst_num_blocks[engine_id]
 
         # Compute the desc ids for each block.
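The diff is cut off before the tail of _get_block_descs_ids, so the actual desc-id arithmetic is not shown here. As a hedged sketch of what the selected region_ids are presumably combined with: if descriptors are laid out region-major, each (region, block) pair maps to desc_id = region_id * num_blocks + block_id.

# Assumed region-major descriptor layout; the real computation is
# elided from this diff.
def block_descs_ids(region_ids: range, block_ids: list[int],
                    num_blocks: int) -> list[int]:
    return [reg_id * num_blocks + block_id
            for reg_id in region_ids
            for block_id in block_ids]

# K/V-per-layer layout [K0, V0, K1, V1, ...] with layer_idx = 1 selects
# regions 2 and 3 (K1 and V1), exactly as in the branch above.
print(block_descs_ids(range(2, 4), block_ids=[7, 8], num_blocks=100))
# -> [207, 208, 307, 308]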