@@ -27,13 +27,15 @@
     Idefics2Config, Idefics2VisionConfig)

 from vllm.attention.layer import MultiHeadAttention
-from vllm.distributed import divide, get_tensor_model_parallel_world_size
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
+                                               ReplicatedLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.multimodal.utils import run_dp_sharded_vision_model


 class Idefics2VisionEmbeddings(nn.Module):
@@ -118,6 +120,7 @@ def __init__(
         config: Idefics2VisionConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
         self.config = config
@@ -130,22 +133,43 @@ def __init__(
                 f" {self.num_heads}).")
         self.scale = self.head_dim**-0.5
         self.dropout = config.attention_dropout
-        self.qkv_proj = QKVParallelLinear(
-            self.embed_dim,
-            self.head_dim,
-            self.num_heads,
-            quant_config=quant_config,
-            prefix=f"{prefix}.qkv_proj",
-        )
-        self.out_proj = RowParallelLinear(
-            self.embed_dim,
-            self.embed_dim,
-            bias=True,
-            quant_config=quant_config,
-            prefix=f"{prefix}.out_proj",
-        )
-        self.tp_size = get_tensor_model_parallel_world_size()
-        self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
+
+        tp_size = (1 if use_data_parallel else
+                   get_tensor_model_parallel_world_size())
+        assert self.num_heads % tp_size == 0
+        self.num_heads_per_partition = self.num_heads // tp_size
+
+        if use_data_parallel:
+            self.q_size = self.num_heads * self.head_dim
+            self.qkv_proj = ReplicatedLinear(
+                self.embed_dim,
+                3 * self.q_size,
+                bias=True,
+                quant_config=quant_config,
+                prefix=f"{prefix}.qkv_proj",
+            )
+            self.out_proj = ReplicatedLinear(
+                self.embed_dim,
+                self.embed_dim,
+                bias=True,
+                quant_config=quant_config,
+                prefix=f"{prefix}.out_proj",
+            )
+        else:
+            self.qkv_proj = QKVParallelLinear(
+                self.embed_dim,
+                self.head_dim,
+                self.num_heads,
+                quant_config=quant_config,
+                prefix=f"{prefix}.qkv_proj",
+            )
+            self.out_proj = RowParallelLinear(
+                self.embed_dim,
+                self.embed_dim,
+                bias=True,
+                quant_config=quant_config,
+                prefix=f"{prefix}.out_proj",
+            )
         self.attn = MultiHeadAttention(self.num_heads_per_partition,
                                        self.head_dim, self.scale)

@@ -169,18 +193,23 @@ def __init__(
         config: Idefics2VisionConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
         self.config = config
         self.activation_fn = get_act_fn(config.hidden_act)
-        self.fc1 = ColumnParallelLinear(
+        cls_fc1 = (ReplicatedLinear
+                   if use_data_parallel else ColumnParallelLinear)
+        self.fc1 = cls_fc1(
             config.hidden_size,
             config.intermediate_size,
             bias=True,
             quant_config=quant_config,
             prefix=f"{prefix}.fc1",
         )
-        self.fc2 = RowParallelLinear(
+        cls_fc2 = (ReplicatedLinear
+                   if use_data_parallel else RowParallelLinear)
+        self.fc2 = cls_fc2(
             config.intermediate_size,
             config.hidden_size,
             bias=True,
@@ -202,17 +231,21 @@ def __init__(
         config: Idefics2Config,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
         self.embed_dim = config.hidden_size
-        self.self_attn = Idefics2VisionAttention(config,
-                                                 quant_config=quant_config,
-                                                 prefix=f"{prefix}.self_attn")
+        self.self_attn = Idefics2VisionAttention(
+            config,
+            quant_config=quant_config,
+            prefix=f"{prefix}.self_attn",
+            use_data_parallel=use_data_parallel)
         self.layer_norm1 = nn.LayerNorm(self.embed_dim,
                                         eps=config.layer_norm_eps)
         self.mlp = Idefics2VisionMLP(config,
                                      quant_config=quant_config,
-                                     prefix=f"{prefix}.mlp")
+                                     prefix=f"{prefix}.mlp",
+                                     use_data_parallel=use_data_parallel)
         self.layer_norm2 = nn.LayerNorm(self.embed_dim,
                                         eps=config.layer_norm_eps)

@@ -254,6 +287,7 @@ def __init__(
         *,
         num_hidden_layers_override: Optional[int] = None,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()

@@ -267,7 +301,8 @@ def __init__(
         self.layers = nn.ModuleList([
             Idefics2EncoderLayer(config,
                                  quant_config=quant_config,
-                                 prefix=f"{prefix}.layers.{layer_idx}")
+                                 prefix=f"{prefix}.layers.{layer_idx}",
+                                 use_data_parallel=use_data_parallel)
             for layer_idx in range(num_hidden_layers)
         ])

@@ -301,17 +336,20 @@ def __init__(
         num_hidden_layers_override: Optional[int] = None,
         require_post_norm: bool = True,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()

         embed_dim = config.hidden_size
         self.config = config
+        self.use_data_parallel = use_data_parallel
         self.embeddings = Idefics2VisionEmbeddings(config)
         self.encoder = Idefics2Encoder(
             config,
             quant_config=quant_config,
             num_hidden_layers_override=num_hidden_layers_override,
-            prefix=f"{prefix}.encoder")
+            prefix=f"{prefix}.encoder",
+            use_data_parallel=use_data_parallel)

         num_hidden_layers = config.num_hidden_layers
         if len(self.encoder.layers) > config.num_hidden_layers:
@@ -340,10 +378,38 @@ def forward(
             patch_attention_mask=patch_attention_mask,
             tgt_sizes=tgt_sizes,
         )
-        encoder_outputs = self.encoder(hidden_states)
+        if self.use_data_parallel:
+            encoder_outputs = run_dp_sharded_vision_model(
+                hidden_states, self.encoder)
+        else:
+            encoder_outputs = self.encoder(hidden_states)
         last_hidden_state = self.post_layernorm(encoder_outputs)
         return last_hidden_state

+    def _consolidate_qkv_weights(
+            self, weights: Iterable[tuple[str, torch.Tensor]]
+    ) -> Iterable[tuple[str, torch.Tensor]]:
+        qkv_idx_mappings = {
+            ".self_attn.q_proj": 0,
+            ".self_attn.k_proj": 1,
+            ".self_attn.v_proj": 2,
+        }
+        qkv_weights = {}
+        for name, loaded_weight in weights:
+            for weight_name, idx in qkv_idx_mappings.items():
+                if weight_name not in name:
+                    continue
+                new_name = name.replace(weight_name, ".self_attn.qkv_proj")
+                if new_name not in qkv_weights:
+                    qkv_weights[new_name] = [None] * 3
+                qkv_weights[new_name][idx] = loaded_weight
+                break
+            else:
+                yield name, loaded_weight
+        for key, weight in qkv_weights.items():
+            qkv_weight = torch.cat(weight, dim=0)
+            yield key, qkv_weight
+
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         stacked_params_mapping = [
@@ -356,6 +422,9 @@ def load_weights(self, weights: Iterable[tuple[str,
         loaded_params: set[str] = set()
         layer_count = len(self.encoder.layers)

+        if self.use_data_parallel:
+            weights = self._consolidate_qkv_weights(weights)
+
         for name, loaded_weight in weights:
             # skip pooling header
             if name.startswith("head."):
@@ -373,7 +442,7 @@ def load_weights(self, weights: Iterable[tuple[str,
                     continue

             for param_name, weight_name, shard_id in stacked_params_mapping:
-                if weight_name not in name:
+                if weight_name not in name or self.use_data_parallel:
                     continue
                 name = name.replace(weight_name, param_name)
                 param = params_dict[name]
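
Note on the weight-loading change: on the data-parallel path the checkpoint's separate q_proj/k_proj/v_proj tensors are folded into the single fused qkv_proj layer by concatenating them along the output dimension, which is what _consolidate_qkv_weights does before the regular loading loop. A minimal standalone sketch of that consolidation step, assuming toy tensor sizes and an illustrative parameter name rather than the real model config:

import torch

embed_dim = 8  # toy size, for illustration only

# Separate projection weights as they would appear in the checkpoint.
checkpoint = {
    "encoder.layers.0.self_attn.q_proj.weight": torch.randn(embed_dim, embed_dim),
    "encoder.layers.0.self_attn.k_proj.weight": torch.randn(embed_dim, embed_dim),
    "encoder.layers.0.self_attn.v_proj.weight": torch.randn(embed_dim, embed_dim),
}

# Concatenating q, k, v along dim 0 matches the fused qkv_proj layer,
# whose output size on the replicated path is 3 * num_heads * head_dim.
fused = torch.cat([
    checkpoint["encoder.layers.0.self_attn.q_proj.weight"],
    checkpoint["encoder.layers.0.self_attn.k_proj.weight"],
    checkpoint["encoder.layers.0.self_attn.v_proj.weight"],
], dim=0)
assert fused.shape == (3 * embed_dim, embed_dim)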
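The forward change follows the same trade-off: with use_data_parallel the vision weights are fully replicated (ReplicatedLinear everywhere) and run_dp_sharded_vision_model shards the image batch across ranks instead of sharding each layer. A rough single-process sketch of that idea, not the actual vLLM helper (the function name and the even batch chunking are assumptions for illustration):

import torch
import torch.nn as nn


def dp_sharded_forward_sketch(hidden_states: torch.Tensor, encoder: nn.Module,
                              world_size: int) -> torch.Tensor:
    # Illustrative only: each "rank" runs a full encoder replica on its
    # slice of the batch; the concatenation stands in for the all-gather.
    chunks = torch.chunk(hidden_states, world_size, dim=0)
    outputs = [encoder(chunk) for chunk in chunks]
    return torch.cat(outputs, dim=0)


# Tiny stand-in encoder so the sketch runs on its own.
encoder = nn.Linear(16, 16)
batch = torch.randn(4, 16)
out = dp_sharded_forward_sketch(batch, encoder, world_size=2)
assert out.shape == (4, 16)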