@@ -34,15 +34,20 @@ def __init__(
         vllm_config: VllmConfig,
         prefix: str = "",
         config: Optional[LlamaConfig] = None,
+        layer_idx: int = 0,
     ) -> None:
         super().__init__(vllm_config, prefix=prefix, config=config)

         config = config or vllm_config.model_config.hf_config
         quant_config = self.get_quant_config(vllm_config)

+        # First layer uses 2*hidden_size (embeds + hidden_states concatenated)
+        # Subsequent layers use hidden_size (only hidden_states, no embeds)
+        qkv_input_size = 2 * self.hidden_size if layer_idx == 0 else self.hidden_size
+
         # override qkv
         self.self_attn.qkv_proj = QKVParallelLinear(
-            2 * self.hidden_size,
+            qkv_input_size,
             self.self_attn.head_dim,
             self.self_attn.total_num_heads,
             self.self_attn.total_num_kv_heads,
@@ -52,6 +57,7 @@ def __init__(
         )

         self.hidden_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.layer_idx = layer_idx

         if getattr(config, "norm_before_residual", False):
             self._residual_norm = self._norm_before_residual
@@ -90,11 +96,15 @@ def forward(
         hidden_states: torch.Tensor,
         residual: Optional[torch.Tensor],
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        embeds = self.input_layernorm(embeds)
-
-        hidden_states, residual = self._residual_norm(hidden_states=hidden_states)
+        if self.layer_idx == 0:
+            # First layer: concatenate embeds with hidden_states
+            embeds = self.input_layernorm(embeds)
+            hidden_states, residual = self._residual_norm(hidden_states=hidden_states)
+            hidden_states = torch.cat([embeds, hidden_states], dim=-1)
+        else:
+            # Subsequent layers: process hidden_states and residuals only
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)

-        hidden_states = torch.cat([embeds, hidden_states], dim=-1)
         # Self Attention
         hidden_states = self.self_attn(
             positions=positions,
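The shape logic above can be illustrated with a minimal, self-contained PyTorch sketch (hypothetical names, a plain `nn.Linear` standing in for `QKVParallelLinear`): only layer 0 sees the concatenated `[embeds, hidden_states]` input, so only its QKV projection needs an input width of `2 * hidden_size`.

```python
import torch
import torch.nn as nn

hidden_size, num_tokens = 64, 3
embeds = torch.randn(num_tokens, hidden_size)         # target-model embeddings
hidden_states = torch.randn(num_tokens, hidden_size)  # draft hidden states

# Layer 0 consumes embeds + hidden_states concatenated; later layers only hidden_states.
qkv_layer0 = nn.Linear(2 * hidden_size, 3 * hidden_size, bias=False)
qkv_later = nn.Linear(hidden_size, 3 * hidden_size, bias=False)

layer0_input = torch.cat([embeds, hidden_states], dim=-1)  # (num_tokens, 2 * hidden_size)
print(qkv_layer0(layer0_input).shape)  # torch.Size([3, 192])
print(qkv_later(hidden_states).shape)  # torch.Size([3, 192])
```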
@@ -133,9 +143,11 @@ def __init__(
             [
                 LlamaDecoderLayer(
                     current_vllm_config,
-                    prefix=maybe_prefix(prefix, f"layers.{start_layer_id}"),
+                    prefix=maybe_prefix(prefix, f"layers.{layer_idx + start_layer_id}"),
                     config=self.config,
+                    layer_idx=layer_idx,
                 )
+                for layer_idx in range(self.config.num_hidden_layers)
             ]
         )
         if hasattr(self.config, "target_hidden_size"):
@@ -166,13 +178,13 @@ def forward(
         assert hidden_states.shape[-1] == input_embeds.shape[-1]

         residual = None
-        hidden_states, residual = self.layers[0](
-            positions,
-            input_embeds,
-            hidden_states,
-            residual,
-        )
-
+        for layer in self.layers:
+            hidden_states, residual = layer(
+                positions=positions,
+                embeds=input_embeds,
+                hidden_states=hidden_states,
+                residual=residual,
+            )
         hidden_states, hidden_prenorm = self.norm(hidden_states, residual)
         return hidden_states, hidden_prenorm

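For reference, a rough standalone sketch of the resulting multi-layer draft stack (a hypothetical `DraftLayer` stand-in with attention and positions omitted, not the actual `LlamaDecoderLayer`): only layer 0 mixes the embeddings in, and `(hidden_states, residual)` is threaded through every layer exactly as in the rewritten forward loop.

```python
import torch
import torch.nn as nn

class DraftLayer(nn.Module):
    """Toy stand-in: layer 0 takes 2 * hidden_size (embeds concatenated), others hidden_size."""
    def __init__(self, hidden_size: int, layer_idx: int = 0) -> None:
        super().__init__()
        self.layer_idx = layer_idx
        in_size = 2 * hidden_size if layer_idx == 0 else hidden_size
        self.proj = nn.Linear(in_size, hidden_size, bias=False)

    def forward(self, embeds, hidden_states, residual):
        if self.layer_idx == 0:
            # Only the first layer consumes the embeddings.
            hidden_states = torch.cat([embeds, hidden_states], dim=-1)
            residual = torch.zeros_like(embeds)
        hidden_states = self.proj(hidden_states) + residual
        return hidden_states, hidden_states  # (output, new residual)

hidden_size, num_layers, num_tokens = 64, 3, 4
layers = nn.ModuleList(
    [DraftLayer(hidden_size, layer_idx=i) for i in range(num_layers)]
)

input_embeds = torch.randn(num_tokens, hidden_size)
hidden_states = torch.randn(num_tokens, hidden_size)
residual = None
for layer in layers:
    hidden_states, residual = layer(input_embeds, hidden_states, residual)
print(hidden_states.shape)  # torch.Size([4, 64])
```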