 from vllm.attention.layer import Attention
 from vllm.config import _BATCH_SIZES_TO_CAPTURE, CacheConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed.parallel_state import get_pp_group
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (QKVParallelLinear,
@@ -25,9 +26,12 @@
                                                     MambaCacheParams)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
+from vllm.utils import LayerBlockType

-from .interfaces import HasInnerState, SupportsLoRA
-from .utils import maybe_prefix
+from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP
+from .utils import (is_pp_missing_parameter,
+                    make_empty_intermediate_tensors_factory, make_layers,
+                    maybe_prefix)

 KVCache = Tuple[torch.Tensor, torch.Tensor]

@@ -281,16 +285,24 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             org_num_embeddings=config.vocab_size,
         )

-        decoder_layers = []
-        for i in range(config.num_hidden_layers):
-            layer_class = ALL_DECODER_LAYER_TYPES[config.layers_block_type[i]]
-            decoder_layers.append(
-                layer_class(config,
-                            layer_idx=i,
-                            cache_config=cache_config,
-                            quant_config=quant_config,
-                            prefix=f"{prefix}.layers.{i}"))
-        self.layers = nn.ModuleList(decoder_layers)
+        def get_layer(prefix: str):
+            layer_idx = int(prefix.rsplit(".", 1)[1])
+            layer_class = ALL_DECODER_LAYER_TYPES[
+                config.layers_block_type[layer_idx]]
+            return layer_class(
+                config,
+                layer_idx,
+                cache_config,
+                quant_config=quant_config,
+                prefix=prefix,
+            )
+
+        self.start_layer, self.end_layer, self.layers = make_layers(
+            config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers")
+        self.make_empty_intermediate_tensors = (
+            make_empty_intermediate_tensors_factory(
+                ["hidden_states", "residual"], config.hidden_size))
+
         self.final_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)

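Note on this hunk: the model no longer instantiates every decoder layer itself. get_layer recovers the global layer index from the dotted prefix, looks up whether that index is an attention or a Mamba block, and make_layers builds only the slice of layers assigned to the current pipeline-parallel rank, returning start_layer and end_layer for the forward pass. Below is a minimal sketch of that partitioning idea, assuming an even split and a placeholder module for remote layers; the names make_layers_sketch and PlaceholderLayer are illustrative, not vLLM's actual implementation.

# Illustrative sketch only: how a make_layers-style helper can assign a
# contiguous slice of decoder layers to each pipeline-parallel rank.
# PlaceholderLayer and the even-split policy are assumptions of this sketch.
from typing import Callable, Tuple

import torch.nn as nn


class PlaceholderLayer(nn.Module):
    """Stands in for layers that live on another pipeline rank."""


def make_layers_sketch(
    num_layers: int,
    layer_fn: Callable[[str], nn.Module],
    prefix: str,
    pp_rank: int,
    pp_size: int,
) -> Tuple[int, int, nn.ModuleList]:
    per_rank = (num_layers + pp_size - 1) // pp_size
    start = pp_rank * per_rank
    end = min(start + per_rank, num_layers)
    layers = nn.ModuleList([
        layer_fn(f"{prefix}.{i}") if start <= i < end else PlaceholderLayer()
        for i in range(num_layers)
    ])
    return start, end, layers

Passing the dotted per-layer prefix into layer_fn is what lets get_layer above recover the global index via prefix.rsplit(".", 1)[1].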
@@ -304,26 +316,34 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         mamba_cache_params: MambaCacheParams,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        if inputs_embeds is not None:
-            hidden_states = inputs_embeds
+        if get_pp_group().is_first_rank:
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.get_input_embeddings(input_ids)
+            residual = None
         else:
-            hidden_states = self.get_input_embeddings(input_ids)
-        residual = None
-        for i in range(len(self.layers)):
+            assert intermediate_tensors is not None
+            hidden_states = intermediate_tensors["hidden_states"]
+            residual = intermediate_tensors["residual"]
+
+        kv_cache_index = 0
+        mamba_cache_index = 0
+        for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             kv_cache = None
             layer_mamba_cache_params = None
             if isinstance(layer, JambaAttentionDecoderLayer):
-                kv_cache = kv_caches[(i - self.config.attn_layer_offset) //
-                                     self.config.attn_layer_period]
+                kv_cache = kv_caches[kv_cache_index]
+                kv_cache_index += 1
             if isinstance(layer, JambaMambaDecoderLayer):
-                current_state_layer = i - (1 +
-                                           (i - self.config.attn_layer_offset)
-                                           // self.config.attn_layer_period)
+                current_state_layer = mamba_cache_index
                 layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
                     current_state_layer)
+                mamba_cache_index += 1

             hidden_states, residual = layer(
                 positions=positions,
@@ -332,11 +352,17 @@ def forward(
                 attn_metadata=attn_metadata,
                 residual=residual,
                 mamba_cache_params=layer_mamba_cache_params)
+        if not get_pp_group().is_last_rank:
+            return IntermediateTensors({
+                "hidden_states": hidden_states,
+                "residual": residual
+            })
         hidden_states, _ = self.final_layernorm(hidden_states, residual)
         return hidden_states


-class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
+class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
+                       IsHybrid):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
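The forward pass now follows the usual pipeline-parallel pattern: the first rank embeds the input tokens, later ranks unpack hidden_states and residual from the incoming IntermediateTensors, and every rank except the last returns IntermediateTensors instead of running the final layernorm. Because each rank owns only layers [start_layer, end_layer) and holds only its own kv_caches and Mamba cache slots, the old global arithmetic over attn_layer_offset / attn_layer_period no longer lines up; local counters consume the per-rank caches in order instead. A toy, self-contained illustration of that counter-based indexing (the block layout and the two-rank split below are made up for the example):

# Toy example of per-rank cache indexing with local counters.
layers_block_type = ["mamba", "mamba", "attention", "mamba",
                     "mamba", "mamba", "attention", "mamba"]
start_layer, end_layer = 4, 8  # slice owned by the second of two PP ranks

kv_cache_index = 0
mamba_cache_index = 0
for i in range(start_layer, end_layer):
    if layers_block_type[i] == "attention":
        print(f"layer {i}: uses kv_caches[{kv_cache_index}]")
        kv_cache_index += 1
    else:
        print(f"layer {i}: uses mamba cache slot {mamba_cache_index}")
        mamba_cache_index += 1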
@@ -368,6 +394,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

         super().__init__()
         self.config = config
+        self.vllm_config = vllm_config
+        self.model_config = vllm_config.model_config
         self.scheduler_config = scheduler_config
         self.model = JambaModel(vllm_config=vllm_config,
                                 prefix=maybe_prefix(prefix, "model"))
@@ -390,6 +418,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                                                 config.vocab_size)
         self.sampler = get_sampler()

+        self.make_empty_intermediate_tensors = (
+            self.model.make_empty_intermediate_tensors)
+
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)

@@ -406,10 +437,8 @@ def forward(self,
                 self.scheduler_config.max_num_seqs) if self.scheduler_config
                               else max(_BATCH_SIZES_TO_CAPTURE) + 2)

-            layers_type = self.config.layers_block_type
-            num_mamba_layers = sum(
-                [layer_type == "mamba" for layer_type in layers_type])
-
+            num_mamba_layers = self.model_config.get_num_layers_by_block_type(
+                self.vllm_config.parallel_config, LayerBlockType.mamba)
             self.mamba_cache = MambaCacheManager(
                 self.lm_head.weight.dtype, num_mamba_layers, max_batch_size,
                 *self._get_mamba_cache_shape())
@@ -423,7 +452,7 @@ def forward(self,
                                               state_indices_tensor)
         hidden_states = self.model(input_ids, positions, kv_caches,
                                    attn_metadata, mamba_cache_params,
-                                   inputs_embeds)
+                                   intermediate_tensors, inputs_embeds)
         return hidden_states

     def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
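Rather than counting "mamba" entries over the whole layers_block_type list, the cache size now comes from model_config.get_num_layers_by_block_type together with the parallel config, so only the Mamba layers hosted on this pipeline rank are counted and MambaCacheManager allocates state for exactly those layers. A hedged sketch of the counting under that assumption (start_layer/end_layer as produced by make_layers; the real helper lives on vLLM's ModelConfig):

# Sketch only: count the Mamba layers inside this rank's layer slice.
def count_local_mamba_layers(layers_block_type, start_layer, end_layer):
    return sum(1 for block_type in layers_block_type[start_layer:end_layer]
               if block_type == "mamba")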
@@ -504,8 +533,12 @@ def load_weights(self, weights: Iterable[Tuple[str,
                     continue
                 name = name.replace(weight_name, param_name)
                 # Skip loading extra bias for GPTQ models.
+
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                # Skip layers on other devices.
+                if is_pp_missing_parameter(name, self):
+                    continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
@@ -520,6 +553,8 @@ def load_weights(self, weights: Iterable[Tuple[str,
                     if weight_name not in name:
                         continue

+                    if is_pp_missing_parameter(name, self):
+                        continue
                     name = name.replace(weight_name, param_name)
                     param = params_dict[name]
                     weight_loader = param.weight_loader
@@ -533,6 +568,8 @@ def load_weights(self, weights: Iterable[Tuple[str,
                     # Skip loading extra bias for GPTQ models.
                     if name.endswith(".bias") and name not in params_dict:
                         continue
+                    if is_pp_missing_parameter(name, self):
+                        continue

                     param = params_dict[name]
                     weight_loader = getattr(param, "weight_loader",
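In load_weights, weights belonging to layers hosted on other pipeline ranks never appear in this rank's params_dict, so each branch now calls is_pp_missing_parameter before indexing params_dict[name] and skips those entries. The check amounts to comparing the layer index embedded in the weight name against this rank's layer range; a rough, illustrative version follows (the regex and the function name are assumptions, not vLLM's implementation):

import re


# Rough illustration of the skip condition: a weight whose layer index falls
# outside this rank's [start_layer, end_layer) slice is not in params_dict.
def is_param_on_this_rank(name: str, start_layer: int, end_layer: int) -> bool:
    match = re.search(r"\.layers\.(\d+)\.", name)
    if match is None:
        return True  # non-layer weights (embeddings, lm_head, final norm)
    return start_layer <= int(match.group(1)) < end_layer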