 from typing import Iterable, List, Optional, Tuple
 
 import torch
-from transformers import JambaConfig
-from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
-from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
-from mamba_ssm.ops.triton.selective_state_update import selective_state_update
 from torch import nn
-from torch.nn.parameter import Parameter
 
+from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.attention.layer import Attention
+
+from transformers import JambaConfig
+from torch.nn.parameter import Parameter
 from vllm.config import LoRAConfig
 from vllm.distributed import (get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.sequence import SamplerOutput
+from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
+from mamba_ssm.ops.triton.selective_state_update import selective_state_update
+from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 
@@ -43,7 +45,6 @@ class MambaCacheParams:
     ssm_state: torch.Tensor = torch.Tensor()
 
 
-
 # Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
 class JambaMambaMixer(nn.Module):
     """
@@ -124,28 +125,10 @@ def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor):
             input_is_parallel=True,
         )
         self.activation = config.hidden_act
-        self.apply_inner_layernorms = config.mamba_inner_layernorms
-
-        if self.apply_inner_layernorms:
-            self.dt_layernorm = RMSNorm(self.time_step_rank,
-                                        eps=config.rms_norm_eps)
-            self.B_layernorm = RMSNorm(self.ssm_state_size,
-                                       eps=config.rms_norm_eps)
-            self.C_layernorm = RMSNorm(self.ssm_state_size,
-                                       eps=config.rms_norm_eps)
-        else:
-            self.dt_layernorm = None
-            self.B_layernorm = None
-            self.C_layernorm = None
-
-    def _apply_layernorms(self, dt, B, C):
-        if self.dt_layernorm is not None:
-            dt = self.dt_layernorm.forward(dt.contiguous())
-        if self.B_layernorm is not None:
-            B = self.B_layernorm.forward(B.contiguous())
-        if self.C_layernorm is not None:
-            C = self.C_layernorm.forward(C.contiguous())
-        return dt, B, C
+
+        self.dt_layernorm = RMSNorm(self.time_step_rank, eps=config.rms_norm_eps)
+        self.b_layernorm = RMSNorm(self.ssm_state_size, eps=config.rms_norm_eps)
+        self.c_layernorm = RMSNorm(self.ssm_state_size, eps=config.rms_norm_eps)
 
     def mamba_forward(self,
                       hidden_states: torch.Tensor,
@@ -189,7 +172,9 @@ def mamba_forward(self,
             [self.time_step_rank, self.ssm_state_size, self.ssm_state_size],
             dim=-1,
         )
-        time_step, B, C = self._apply_layernorms(time_step, B, C)
+        time_step = self.dt_layernorm(time_step.contiguous())
+        B = self.b_layernorm(B.contiguous())
+        C = self.c_layernorm(C.contiguous())
 
         discrete_time_step = self.dt_proj(time_step)[0].transpose(1, 2)
         # 3.c perform the recurrence y ← SSM(A, B, C)(x)
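
Note (illustrative, not part of the patch): the recurrence in the "3.c" comment is the discretized selective SSM that selective_scan_fn evaluates in a fused kernel. Assuming a single sequence, the per-token dt/B/C produced above, and ignoring the silu(z) output gating that the kernel also applies, a naive reference scan looks roughly like this:

    def naive_selective_scan(x, dt, A, B, C, D):
        # x:  (d_inner, L)   input after the causal conv1d
        # dt: (d_inner, L)   per-token step size, already softplus-ed
        # A:  (d_inner, d_state)
        # B:  (L, d_state)   input-dependent, post b_layernorm
        # C:  (L, d_state)   input-dependent, post c_layernorm
        # D:  (d_inner,)     skip connection
        d_inner, seq_len = x.shape
        h = x.new_zeros(d_inner, A.shape[1])
        outputs = []
        for t in range(seq_len):
            dA = torch.exp(dt[:, t, None] * A)   # zero-order-hold discretization of A
            dB = dt[:, t, None] * B[t]           # Euler discretization of B
            h = dA * h + dB * x[:, t, None]      # state update
            outputs.append((h * C[t]).sum(-1) + D * x[:, t])
        return torch.stack(outputs, dim=-1)      # (d_inner, L)
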
@@ -275,6 +260,36 @@ def forward(
         return hidden_states
 
 
+class JambaMLP(nn.Module):
+    def __init__(
+        self,
+        config: JambaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        hidden_size = config.hidden_size
+        intermediate_size = config.intermediate_size
+        hidden_act = config.hidden_act
+        self.gate_up_proj = MergedColumnParallelLinear(
+            hidden_size, [intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config)
+        self.down_proj = RowParallelLinear(intermediate_size,
+                                           hidden_size,
+                                           bias=False,
+                                           quant_config=quant_config)
+        if hidden_act != "silu":
+            raise ValueError(f"Unsupported activation: {hidden_act}. "
+                             "Only silu is supported for now.")
+        self.act_fn = SiluAndMul()
+
+    def forward(self, x):
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        x, _ = self.down_proj(x)
+        return x
+
+
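
Note (illustrative, not part of the patch): gate_up_proj is a merged column-parallel projection, so its output is the gate and up projections concatenated along the last dimension, and vLLM's SiluAndMul splits that tensor in half and gates. Ignoring tensor parallelism and using hypothetical unsharded weight names, JambaMLP.forward is roughly:

    # w_gate, w_up, w_down are the dense checkpoint weights (hypothetical names)
    gate_up = torch.cat([x @ w_gate.t(), x @ w_up.t()], dim=-1)          # gate_up_proj(x)
    d = gate_up.shape[-1] // 2
    x = torch.nn.functional.silu(gate_up[..., :d]) * gate_up[..., d:]    # SiluAndMul()
    x = x @ w_down.t()                                                   # down_proj(x)
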
 class JambaMoE(nn.Module):
     """A tensor-parallel MoE implementation for Mixtral that shards each expert
     across all ranks.
@@ -285,33 +300,27 @@ class JambaMoE(nn.Module):
     """
 
     def __init__(
-        self,
-        num_experts: int,
-        top_k: int,
-        hidden_size: int,
-        intermediate_size: int,
-        params_dtype: Optional[torch.dtype] = None,
-        tp_size: Optional[int] = None,
+        self,
+        config: JambaConfig,
+        params_dtype: Optional[torch.dtype] = None,
+        tp_size: Optional[int] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.tp_size = tp_size or get_tensor_model_parallel_world_size()
-        self.num_total_experts = num_experts
-        self.top_k = top_k
-        self.hidden_size = hidden_size
-        self.intermediate_size = intermediate_size // self.tp_size
+        self.num_total_experts = config.num_experts
+        self.top_k = config.num_experts_per_tok
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = config.intermediate_size // self.tp_size
 
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
         self.params_dtype = params_dtype
 
-        if self.num_total_experts > 1:
-            # init expert router iff this layer has multiple experts
-            self.router = ReplicatedLinear(
-                self.hidden_size,
-                self.num_total_experts,
-                bias=False,
-                params_dtype=self.params_dtype,
-            )
+        self.router = ReplicatedLinear(self.hidden_size,
+                                       self.num_total_experts,
+                                       bias=False,
+                                       params_dtype=self.params_dtype)
 
         self.ws = nn.Parameter(
             torch.empty(
@@ -366,14 +375,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_size = hidden_states.shape
         hidden_states = hidden_states.view(-1, self.hidden_size)
         # router_logits: (batch * sequence_length, n_experts)
-        if self.num_total_experts > 1:
-            router_logits, _ = self.router(hidden_states)
-        else:
-            router_logits = torch.ones(
-                [hidden_states.shape[0], 1],
-                device=hidden_states.device,
-                dtype=hidden_states.dtype,
-            )
+        router_logits, _ = self.router(hidden_states)
 
         final_hidden_states = fused_moe(
             hidden_states,
@@ -394,28 +396,19 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 
 
 class JambaMambaDecoderLayer(nn.Module):
-
     def __init__(
-        self,
-        config: JambaConfig,
-        actual_num_experts: int,
-        actual_num_experts_per_tok: int,
-        layer_idx: int,
+        self, config: JambaConfig, layer_idx: int, quant_config: Optional[QuantizationConfig] = None
     ) -> None:
         super().__init__()
         self.layer_idx = layer_idx
         self.config = config
         self.mamba = JambaMambaMixer(config, layer_idx)
-        self.moe = JambaMoE(
-            num_experts=actual_num_experts,
-            top_k=actual_num_experts_per_tok,
-            hidden_size=config.hidden_size,
-            intermediate_size=config.intermediate_size,
-        )
-        self.input_layernorm = RMSNorm(config.hidden_size,
-                                       eps=config.rms_norm_eps)
-        self.pre_moe_layernorm = RMSNorm(config.hidden_size,
-                                         eps=config.rms_norm_eps)
+
+        num_experts = config.layers_num_experts[layer_idx]
+        ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP
+        self.feed_forward = ffn_layer_class(config, quant_config)
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.pre_ff_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def forward(
         self,
@@ -436,20 +429,15 @@ def forward(
         hidden_states = self.mamba(hidden_states, attn_metadata, conv_state,
                                    ssm_state)
         # Fully Connected
-        hidden_states, residual = self.pre_moe_layernorm(
-            hidden_states, residual)
-        hidden_states = self.moe(hidden_states)
+        hidden_states, residual = self.pre_ff_layernorm(hidden_states, residual)
+        hidden_states = self.feed_forward(hidden_states)
         return hidden_states, residual
 
 
 class JambaAttentionDecoderLayer(nn.Module):
 
     def __init__(
-        self,
-        config: JambaConfig,
-        actual_num_experts: int,
-        actual_num_experts_per_tok: int,
-        quant_config: Optional[QuantizationConfig] = None,
+        self, config: JambaConfig, layer_idx: int, quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -494,16 +482,11 @@ def __init__(
             sliding_window=self.sliding_window,
         )
 
-        self.moe = JambaMoE(
-            num_experts=actual_num_experts,
-            top_k=actual_num_experts_per_tok,
-            hidden_size=config.hidden_size,
-            intermediate_size=config.intermediate_size,
-        )
-        self.input_layernorm = RMSNorm(config.hidden_size,
-                                       eps=config.rms_norm_eps)
-        self.pre_moe_layernorm = RMSNorm(config.hidden_size,
-                                         eps=config.rms_norm_eps)
+        num_experts = config.layers_num_experts[layer_idx]
+        ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP
+        self.feed_forward = ffn_layer_class(config, quant_config)
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.pre_ff_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def self_attention(
         self,
@@ -542,12 +525,14 @@ def forward(
             attn_metadata=attn_metadata,
         )
         # Fully Connected
-        hidden_states, residual = self.pre_moe_layernorm(
-            hidden_states, residual)
-        hidden_states = self.moe(hidden_states)
+        hidden_states, residual = self.pre_ff_layernorm(hidden_states, residual)
+        hidden_states = self.feed_forward(hidden_states)
         return hidden_states, residual
 
 
+ALL_DECODER_LAYER_TYPES = {"attention": JambaAttentionDecoderLayer, "mamba": JambaMambaDecoderLayer}
+
+
 class JambaModel(nn.Module):
 
     def __init__(
@@ -570,40 +555,12 @@ def __init__(
             org_num_embeddings=config.vocab_size,
         )
 
-        # init each model layer, decide if it's mamba/attention and
-        # has experts and pass it down
-
-        module_list = []
+        decoder_layers = []
         for i in range(config.num_hidden_layers):
-            is_attn = ((i - self.config.attn_layer_offset) %
-                       self.config.attn_layer_period == 0)
-            is_expert = ((i - self.config.expert_layer_offset) %
-                         self.config.expert_layer_period == 0)
-
-            actual_num_experts = config.num_experts if is_expert else 1
-            actual_num_experts_per_tok = config.num_experts_per_tok \
-                if is_expert else 1
-
-            if is_attn:
-                module_list.append(
-                    JambaAttentionDecoderLayer(
-                        config,
-                        actual_num_experts=actual_num_experts,
-                        actual_num_experts_per_tok=actual_num_experts_per_tok,
-                        quant_config=quant_config
-                    ))
-            else:
-                module_list.append(
-                    JambaMambaDecoderLayer(
-                        config,
-                        actual_num_experts=actual_num_experts,
-                        actual_num_experts_per_tok=actual_num_experts_per_tok,
-                        layer_idx=i,
-                    ))
-
-        self.layers = nn.ModuleList(module_list)
-        self.final_layernorm = RMSNorm(config.hidden_size,
-                                       eps=config.rms_norm_eps)
+            layer_class = ALL_DECODER_LAYER_TYPES[config.layers_block_type[i]]
+            decoder_layers.append(layer_class(config, layer_idx=i, quant_config=quant_config))
+        self.layers = nn.ModuleList(decoder_layers)
+        self.final_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
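
Note (illustrative, not part of the patch): config.layers_block_type and config.layers_num_experts come from the HuggingFace JambaConfig and encode the same periodic layout that the removed is_attn / is_expert arithmetic computed by hand. Assuming they mirror that logic, they contain roughly:

    # Per-layer metadata the construction loop now reads (sketch, mirroring
    # the removed offset/period arithmetic above):
    layers_block_type = [
        "attention" if (i - config.attn_layer_offset) % config.attn_layer_period == 0
        else "mamba"
        for i in range(config.num_hidden_layers)
    ]
    layers_num_experts = [
        config.num_experts
        if (i - config.expert_layer_offset) % config.expert_layer_period == 0
        else 1
        for i in range(config.num_hidden_layers)
    ]
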
     def forward(
         self,
@@ -732,6 +689,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             ("qkv_proj", "q_proj", "q"),
             ("qkv_proj", "k_proj", "k"),
             ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
         ]
 
         expert_params_mapping = [
@@ -758,6 +717,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
+                if 'experts' in name:
+                    continue
                 name = name.replace(weight_name, param_name)
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
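
Note (illustrative, not part of the patch): the two new ("gate_up_proj", ...) entries let dense JambaMLP checkpoint weights load into the merged gate_up_proj parameter, while the added 'experts' guard keeps MoE expert weights on the expert_params_mapping path. Using a hypothetical checkpoint tensor name, the loop above behaves roughly like this:

    # Hypothetical weight name, shown only to illustrate the mapping:
    name = "model.layers.3.feed_forward.gate_proj.weight"
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in name and 'experts' not in name:   # matches "gate_proj"
            name = name.replace(weight_name, param_name)
            # -> "model.layers.3.feed_forward.gate_up_proj.weight"; the merged
            # parameter's weight_loader then writes this tensor into shard 0
            # (the gate half), while "up_proj" weights go to shard 1 (the up half).
            break
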