2525from vllm .model_executor .layers .layernorm import RMSNorm
2626from vllm .model_executor .layers .linear import (ColumnParallelLinear ,
2727 QKVParallelLinear ,
28+ ReplicatedLinear ,
2829 RowParallelLinear )
2930from vllm .model_executor .layers .quantization import QuantizationConfig
3031from vllm .model_executor .model_loader .weight_utils import default_weight_loader
32+ from vllm .multimodal .utils import run_dp_sharded_vision_model
3133
3234NORM2FN = {
3335 'rms_norm' : RMSNorm ,
@@ -137,6 +139,7 @@ def __init__(
137139 * ,
138140 num_dummy_heads : int = 0 ,
139141 prefix : str = "" ,
142+ use_data_parallel : bool = False ,
140143 ) -> None :
141144 super ().__init__ ()
142145
@@ -150,23 +153,34 @@ def __init__(
150153 f'(got `embed_dim`: { self .embed_dim } and `num_heads`:'
151154 f' { self .num_heads } ).' )
152155
153- self .tp_size = get_tensor_model_parallel_world_size ()
154- self .tp_rank = get_tensor_model_parallel_rank ()
156+ self .tp_size = (1 if use_data_parallel else
157+ get_tensor_model_parallel_world_size ())
158+ self .tp_rank = (0 if use_data_parallel else
159+ get_tensor_model_parallel_rank ())
155160
156161 # Additional dummy heads are used to enable TP for common GPU counts.
157162 self .dummy_dim = (num_dummy_heads + self .num_heads ) * self .head_dim
158163 self .num_heads_per_partition = divide (num_dummy_heads + self .num_heads ,
159164 self .tp_size )
160165
161166 self .scale = self .head_dim ** - 0.5
162- self .qkv = QKVParallelLinear (
163- self .embed_dim ,
164- self .head_dim ,
165- num_dummy_heads + self .num_heads ,
166- bias = config .qkv_bias ,
167- quant_config = quant_config ,
168- prefix = f"{ prefix } .qkv" ,
169- )
167+ if use_data_parallel :
168+ self .qkv = ReplicatedLinear (
169+ self .embed_dim ,
170+ 3 * self .head_dim * self .num_heads ,
171+ bias = config .qkv_bias ,
172+ quant_config = quant_config ,
173+ prefix = f"{ prefix } .qkv" ,
174+ )
175+ else :
176+ self .qkv = QKVParallelLinear (
177+ self .embed_dim ,
178+ self .head_dim ,
179+ num_dummy_heads + self .num_heads ,
180+ bias = config .qkv_bias ,
181+ quant_config = quant_config ,
182+ prefix = f"{ prefix } .qkv" ,
183+ )
170184
171185 self .qk_normalization = config .qk_normalization
172186
@@ -178,12 +192,20 @@ def __init__(
178192 eps = config .layer_norm_eps ,
179193 var_hidden_size = self .embed_dim )
180194
181- self .proj = RowParallelLinear (
182- self .dummy_dim ,
183- self .embed_dim ,
184- quant_config = quant_config ,
185- prefix = f"{ prefix } .proj" ,
186- )
195+ if use_data_parallel :
196+ self .proj = ReplicatedLinear (
197+ self .dummy_dim ,
198+ self .embed_dim ,
199+ quant_config = quant_config ,
200+ prefix = f"{ prefix } .proj" ,
201+ )
202+ else :
203+ self .proj = RowParallelLinear (
204+ self .dummy_dim ,
205+ self .embed_dim ,
206+ quant_config = quant_config ,
207+ prefix = f"{ prefix } .proj" ,
208+ )
187209
188210 self .attn = MultiHeadAttention (self .num_heads_per_partition ,
189211 self .head_dim , self .scale )
@@ -287,21 +309,26 @@ def __init__(
287309 config : PretrainedConfig ,
288310 quant_config : Optional [QuantizationConfig ] = None ,
289311 prefix : str = "" ,
312+ use_data_parallel : bool = False ,
290313 ) -> None :
291314 super ().__init__ ()
292315
293316 self .config = config
294317 self .activation_fn = get_act_fn (config .hidden_act )
295- self .fc1 = ColumnParallelLinear (config .hidden_size ,
296- config .intermediate_size ,
297- bias = True ,
298- quant_config = quant_config ,
299- prefix = f"{ prefix } .fc1" )
300- self .fc2 = RowParallelLinear (config .intermediate_size ,
301- config .hidden_size ,
302- bias = True ,
303- quant_config = quant_config ,
304- prefix = f"{ prefix } .fc2" )
318+ cls_fc1 = (ReplicatedLinear
319+ if use_data_parallel else ColumnParallelLinear )
320+ self .fc1 = cls_fc1 (config .hidden_size ,
321+ config .intermediate_size ,
322+ bias = True ,
323+ quant_config = quant_config ,
324+ prefix = f"{ prefix } .fc1" )
325+ cls_fc2 = (ReplicatedLinear
326+ if use_data_parallel else RowParallelLinear )
327+ self .fc2 = cls_fc2 (config .intermediate_size ,
328+ config .hidden_size ,
329+ bias = True ,
330+ quant_config = quant_config ,
331+ prefix = f"{ prefix } .fc2" )
305332
306333 def forward (self , hidden_states : torch .Tensor ) -> torch .Tensor :
307334 hidden_states , _ = self .fc1 (hidden_states )
@@ -320,6 +347,7 @@ def __init__(
320347 * ,
321348 num_dummy_heads : int = 0 ,
322349 prefix : str = "" ,
350+ use_data_parallel : bool = False ,
323351 ) -> None :
324352 super ().__init__ ()
325353
@@ -330,11 +358,13 @@ def __init__(
330358 self .attn = self ._init_attn (config ,
331359 quant_config ,
332360 num_dummy_heads = num_dummy_heads ,
333- prefix = f"{ prefix } .attn" )
361+ prefix = f"{ prefix } .attn" ,
362+ use_data_parallel = use_data_parallel )
334363
335364 self .mlp = InternMLP (config ,
336365 quant_config = quant_config ,
337- prefix = f"{ prefix } .mlp" )
366+ prefix = f"{ prefix } .mlp" ,
367+ use_data_parallel = use_data_parallel )
338368 self .norm1 = NORM2FN [self .norm_type ](self .embed_dim ,
339369 eps = config .layer_norm_eps )
340370 self .norm2 = NORM2FN [self .norm_type ](self .embed_dim ,
@@ -352,16 +382,20 @@ def _init_attn(
352382 * ,
353383 num_dummy_heads : int ,
354384 prefix : str = "" ,
385+ use_data_parallel : bool = False ,
355386 ):
356387 # fallback to sdpa attention if tp unavailable
357- tp_size = get_tensor_model_parallel_world_size ()
388+ # tp_size = get_tensor_model_parallel_world_size()
389+ tp_size = (1 if use_data_parallel else
390+ get_tensor_model_parallel_world_size ())
358391 num_heads = config .num_attention_heads
359392
360393 if (num_heads + num_dummy_heads ) % tp_size == 0 :
361394 return InternParallelAttention (config ,
362395 quant_config = quant_config ,
363396 num_dummy_heads = num_dummy_heads ,
364- prefix = prefix )
397+ prefix = prefix ,
398+ use_data_parallel = use_data_parallel )
365399
366400 return InternSdpaAttention (config , num_dummy_heads = num_dummy_heads )
367401
@@ -388,6 +422,7 @@ def __init__(
388422 num_hidden_layers_override : Optional [int ] = None ,
389423 num_dummy_heads : int = 0 ,
390424 prefix : str = "" ,
425+ use_data_parallel : bool = False ,
391426 ):
392427 super ().__init__ ()
393428
@@ -402,7 +437,8 @@ def __init__(
402437 InternVisionEncoderLayer (config ,
403438 quant_config ,
404439 num_dummy_heads = num_dummy_heads ,
405- prefix = f"{ prefix } .layers.{ layer_idx } " )
440+ prefix = f"{ prefix } .layers.{ layer_idx } " ,
441+ use_data_parallel = use_data_parallel )
406442 for layer_idx in range (num_hidden_layers )
407443 ])
408444
@@ -429,10 +465,12 @@ def __init__(
429465 num_hidden_layers_override : Optional [int ] = None ,
430466 num_dummy_heads : int = 0 ,
431467 prefix : str = "" ,
468+ use_data_parallel : bool = False ,
432469 ) -> None :
433470 super ().__init__ ()
434471
435472 self .config = config
473+ self .use_data_parallel = use_data_parallel
436474
437475 self .embeddings = InternVisionEmbeddings (config )
438476 self .encoder = InternVisionEncoder (
@@ -441,6 +479,7 @@ def __init__(
441479 num_hidden_layers_override = num_hidden_layers_override ,
442480 num_dummy_heads = num_dummy_heads ,
443481 prefix = f"{ prefix } .encoder" ,
482+ use_data_parallel = use_data_parallel ,
444483 )
445484
446485 def get_input_embeddings (self ):
@@ -464,7 +503,11 @@ def forward(
464503 raise ValueError (
465504 f'wrong pixel_values size: { pixel_values .shape } ' )
466505
467- encoder_outputs = self .encoder (inputs_embeds = hidden_states )
506+ if self .use_data_parallel :
507+ encoder_outputs = run_dp_sharded_vision_model (
508+ hidden_states , self .encoder )
509+ else :
510+ encoder_outputs = self .encoder (inputs_embeds = hidden_states )
468511
469512 return encoder_outputs
470513
0 commit comments