 from transformers.video_utils import VideoMetadata
 
 from vllm.config import VllmConfig
-from vllm.distributed import parallel_state
+from vllm.distributed import (get_tensor_model_parallel_world_size,
+                              parallel_state)
 from vllm.distributed import utils as dist_utils
 from vllm.logger import init_logger
 from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.layers.layernorm import RMSNorm
+# yapf: disable
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
+                                               MergedReplicatedLinear,
                                                QKVParallelLinear,
+                                               ReplicatedLinear,
                                                RowParallelLinear)
+# yapf: enable
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.module_mapping import MultiModelKeys
                                         BaseProcessingInfo, PromptReplacement,
                                         PromptUpdate, PromptUpdateDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
+from vllm.multimodal.utils import run_dp_sharded_mrope_vision_model
 from vllm.platforms import _Backend
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.config import uses_mrope
@@ -153,7 +159,7 @@ class Glm4vVideoEmbeddingInputs(TensorSchema):
 
 Glm4vVideoInputs = Union[Glm4vVideoPixelInputs, Glm4vVideoEmbeddingInputs]
 
-# === Vision Encoder === #
+# ==== Vision Encoder ==== #
 
 
 class Glm4vVisionMLP(nn.Module):
@@ -165,19 +171,23 @@ def __init__(
         bias: bool = False,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ):
         super().__init__()
-        self.gate_up_proj = MergedColumnParallelLinear(
-            input_size=in_features,
-            output_sizes=[hidden_features] * 2,
-            bias=bias,
-            quant_config=quant_config,
-            prefix=f"{prefix}.gate_up_proj")
-        self.down_proj = RowParallelLinear(hidden_features,
-                                           in_features,
-                                           bias=bias,
-                                           quant_config=quant_config,
-                                           prefix=f"{prefix}.down_proj")
+        cls_gate_up = (MergedReplicatedLinear
+                       if use_data_parallel else MergedColumnParallelLinear)
+        self.gate_up_proj = cls_gate_up(input_size=in_features,
+                                        output_sizes=[hidden_features] * 2,
+                                        bias=bias,
+                                        quant_config=quant_config,
+                                        prefix=f"{prefix}.gate_up_proj")
+        cls_down = (ReplicatedLinear
+                    if use_data_parallel else RowParallelLinear)
+        self.down_proj = cls_down(hidden_features,
+                                  in_features,
+                                  bias=bias,
+                                  quant_config=quant_config,
+                                  prefix=f"{prefix}.down_proj")
         self.act_fn = SiluAndMul()
 
     def forward(self, x: torch.Tensor):
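Note: the class-swap idiom above is the core of the DP path: replicate the weights, shard the batch. In data-parallel mode every rank holds the full MLP weights and processes its own slice of the images, while tensor parallelism instead shards each weight across ranks. A minimal sketch of the selection logic (illustrative only, not part of this diff):

```python
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               MergedReplicatedLinear)


def pick_gate_up_cls(use_data_parallel: bool):
    # DP: full weights on every rank, each rank sees different images.
    # TP: the output dimension is sharded across tensor-parallel ranks.
    return (MergedReplicatedLinear
            if use_data_parallel else MergedColumnParallelLinear)


assert pick_gate_up_cls(True) is MergedReplicatedLinear
assert pick_gate_up_cls(False) is MergedColumnParallelLinear
```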
@@ -218,33 +228,54 @@ def __init__(
         projection_size: int,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
         # Per attention head and per partition values.
-        self.tp_size = parallel_state.get_tensor_model_parallel_world_size()
+        self.tp_size = (1 if use_data_parallel else
+                        get_tensor_model_parallel_world_size())
         self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
         self.hidden_size_per_attention_head = dist_utils.divide(
             projection_size, num_heads)
         self.num_attention_heads_per_partition = dist_utils.divide(
             num_heads, self.tp_size)
 
-        self.qkv = QKVParallelLinear(
-            hidden_size=embed_dim,
-            head_size=self.hidden_size_per_attention_head,
-            total_num_heads=num_heads,
-            total_num_kv_heads=num_heads,
-            bias=False,
-            quant_config=quant_config,
-            # Change qkv prefix to align with GLM-4.5V-FP8 quantization config
-            prefix=f"{prefix}.qkv_proj" if quant_config else f"{prefix}.qkv",
-        )
-        self.proj = RowParallelLinear(
-            input_size=projection_size,
-            output_size=embed_dim,
-            quant_config=quant_config,
-            prefix=f"{prefix}.proj",
-            bias=False,
-        )
+        if use_data_parallel:
+            self.qkv = ReplicatedLinear(
+                input_size=embed_dim,
+                output_size=3 * projection_size,
+                bias=False,
+                quant_config=quant_config,
+                # Change qkv prefix to align with GLM-4.5V-FP8 quantization cfg
+                prefix=f"{prefix}.qkv_proj"
+                if quant_config else f"{prefix}.qkv",
+            )
+            self.proj = ReplicatedLinear(
+                input_size=projection_size,
+                output_size=embed_dim,
+                quant_config=quant_config,
+                prefix=f"{prefix}.proj",
+                bias=False,
+            )
+        else:
+            self.qkv = QKVParallelLinear(
+                hidden_size=embed_dim,
+                head_size=self.hidden_size_per_attention_head,
+                total_num_heads=num_heads,
+                total_num_kv_heads=num_heads,
+                bias=False,
+                quant_config=quant_config,
+                # Change qkv prefix to align with GLM-4.5V-FP8 quantization cfg
+                prefix=f"{prefix}.qkv_proj"
+                if quant_config else f"{prefix}.qkv",
+            )
+            self.proj = RowParallelLinear(
+                input_size=projection_size,
+                output_size=embed_dim,
+                quant_config=quant_config,
+                prefix=f"{prefix}.proj",
+                bias=False,
+            )
 
         # Detect attention implementation.
         self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
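Note: with `tp_size` pinned to 1 in DP mode, `dist_utils.divide(num_heads, self.tp_size)` leaves every head on each rank, and the replicated qkv emits the full `3 * projection_size` output that `QKVParallelLinear` would otherwise shard. A worked example with hypothetical sizes (not taken from the GLM-4.1V config):

```python
# Hypothetical numbers for illustration.
num_heads, projection_size = 16, 2048

head_dim = projection_size // num_heads  # 128 in both modes

# TP mode on 4 GPUs: each rank computes 4 of the 16 heads.
heads_per_rank_tp = num_heads // 4
# DP mode: tp_size is forced to 1, so each rank keeps all 16 heads
# and instead receives only a shard of the images.
heads_per_rank_dp = num_heads // 1

assert (head_dim, heads_per_rank_tp, heads_per_rank_dp) == (128, 4, 16)
```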
@@ -375,6 +406,7 @@ def __init__(
         norm_layer: Optional[Callable[[int], nn.Module]] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
         if norm_layer is None:
@@ -387,13 +419,15 @@ def __init__(
             projection_size=dim,
             quant_config=quant_config,
             prefix=f"{prefix}.attn",
+            use_data_parallel=use_data_parallel,
         )
         self.mlp = Glm4vVisionMLP(
             dim,
             mlp_hidden_dim,
             bias=False,
             quant_config=quant_config,
             prefix=f"{prefix}.mlp",
+            use_data_parallel=use_data_parallel,
         )
 
     def forward(
@@ -456,24 +490,40 @@ def __init__(
         quant_config: Optional[QuantizationConfig] = None,
         bias: bool = False,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
         self.hidden_size = d_model
-        self.proj = ColumnParallelLinear(self.hidden_size,
-                                         self.hidden_size,
-                                         bias=bias,
-                                         gather_output=True,
-                                         quant_config=quant_config,
-                                         prefix=f"{prefix}.proj")
+        if use_data_parallel:
+            self.proj = ReplicatedLinear(
+                input_size=self.hidden_size,
+                output_size=self.hidden_size,
+                bias=bias,
+                quant_config=quant_config,
+                prefix=f"{prefix}.proj",
+            )
+        else:
+            self.proj = ColumnParallelLinear(
+                self.hidden_size,
+                self.hidden_size,
+                bias=bias,
+                gather_output=True,
+                quant_config=quant_config,
+                prefix=f"{prefix}.proj",
+            )
         self.post_projection_norm = nn.LayerNorm(self.hidden_size)
-        self.gate_up_proj = MergedColumnParallelLinear(
+        cls_gate_up = (MergedReplicatedLinear
+                       if use_data_parallel else MergedColumnParallelLinear)
+        self.gate_up_proj = cls_gate_up(
             input_size=self.hidden_size,
             output_sizes=[context_dim] * 2,
             bias=bias,
             quant_config=quant_config,
             prefix=f"{prefix}.gate_up_proj",
         )
-        self.down_proj = RowParallelLinear(
+        cls_down = (ReplicatedLinear
+                    if use_data_parallel else RowParallelLinear)
+        self.down_proj = cls_down(
             context_dim,
             self.hidden_size,
             bias=bias,
@@ -548,14 +598,33 @@ def forward(self, embeddings, lengths, image_shapes, h_coords,
                                                         dtype=torch.float32))
 
             # Calculate target dimensions for each patch
-            target_h = torch.cat([
-                image_shapes[i, 1].repeat(lengths[i])
-                for i in range(len(lengths))
-            ]).to(device=device, dtype=torch.float32)
-            target_w = torch.cat([
-                image_shapes[i, 2].repeat(lengths[i])
-                for i in range(len(lengths))
-            ]).to(device=device, dtype=torch.float32)
+            # Add bounds checking for data parallel mode
+            if len(lengths) > image_shapes.shape[0]:
+                # In data parallel mode, some GPUs might not have all
+                # image shapes.
+                # Use available image shapes, cycling if necessary.
+                target_h_list = []
+                target_w_list = []
+                for i in range(len(lengths)):
+                    # Cycle through available shapes
+                    shape_idx = i % image_shapes.shape[0]
+                    target_h_list.append(image_shapes[shape_idx,
+                                                      1].repeat(lengths[i]))
+                    target_w_list.append(image_shapes[shape_idx,
+                                                      2].repeat(lengths[i]))
+                target_h = torch.cat(target_h_list).to(device=device,
+                                                       dtype=torch.float32)
+                target_w = torch.cat(target_w_list).to(device=device,
+                                                       dtype=torch.float32)
+            else:
+                target_h = torch.cat([
+                    image_shapes[i, 1].repeat(lengths[i])
+                    for i in range(len(lengths))
+                ]).to(device=device, dtype=torch.float32)
+                target_w = torch.cat([
+                    image_shapes[i, 2].repeat(lengths[i])
+                    for i in range(len(lengths))
+                ]).to(device=device, dtype=torch.float32)
 
             # Normalize coordinates to [-1, 1] range for grid_sample
             h_coords = h_coords.to(device=device, dtype=torch.float32)
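Note: the fallback above cycles through whatever shapes are present on a rank when there are more length entries than shapes. A self-contained sketch of the indexing, with made-up values:

```python
import torch

# Made-up values: three patch groups but only two image shapes on this rank.
lengths = [4, 6, 5]
image_shapes = torch.tensor([[1, 2, 3],
                             [1, 7, 5]])  # rows are (t, h, w)

target_h = torch.cat([
    image_shapes[i % image_shapes.shape[0], 1].repeat(lengths[i])
    for i in range(len(lengths))
])
# Group 2 wraps around to shape 0, so h = 2 is reused.
assert target_h.tolist() == [2] * 4 + [7] * 6 + [2] * 5
```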
@@ -629,6 +698,7 @@ def __init__(
         norm_eps: float = 1e-6,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
 
@@ -638,6 +708,7 @@ def __init__(
         depth = vision_config.depth
         self.hidden_size = vision_config.hidden_size
         self.num_heads = vision_config.num_heads
+        self.use_data_parallel = use_data_parallel
 
         self.patch_size = vision_config.patch_size
         self.spatial_merge_size = vision_config.spatial_merge_size
@@ -661,6 +732,7 @@ def __init__(
                 norm_layer=norm_layer,
                 quant_config=quant_config,
                 prefix=f"{prefix}.blocks.{layer_idx}",
+                use_data_parallel=self.use_data_parallel,
             ) for layer_idx in range(depth)
         ])
         self.merger = Glm4vPatchMerger(
@@ -669,6 +741,7 @@ def __init__(
             quant_config=quant_config,
             bias=False,
             prefix=f"{prefix}.merger",
+            use_data_parallel=self.use_data_parallel,
         )
         self.embeddings = Glm4vVisionEmbeddings(vision_config)
 
@@ -731,8 +804,11 @@ def compute_attn_mask_seqlen(
     def forward(
         self,
         x: torch.Tensor,
-        grid_thw: torch.Tensor,
+        grid_thw: list[list[int]],
     ) -> torch.Tensor:
+        # Convert grid_thw to tensor (always expecting list format now)
+        grid_thw = torch.tensor(grid_thw, device=x.device, dtype=torch.long)
+
         # patchify
         x = x.to(device=self.device, dtype=self.dtype)
         x = self.patch_embed(x)
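Note: `forward()` now takes `grid_thw` as a nested Python list and rebuilds the tensor itself, presumably so callers (including the DP sharding helper, which receives `grid_thw.tolist()` below) can pass plain lists. A quick sketch of the conversion with illustrative values:

```python
import torch

grid_thw = [[1, 4, 6], [2, 8, 8]]  # one (t, h, w) row per media item
grid = torch.tensor(grid_thw, dtype=torch.long)
assert grid.shape == (2, 3) and grid.prod(-1).tolist() == [24, 128]
```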
@@ -1250,6 +1326,8 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
             "model.visual.": "visual.",
         })
 
+    supports_encoder_tp_data = True
+
     @classmethod
     def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
         if modality.startswith("image"):
@@ -1267,12 +1345,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
 
         self.config = config
         self.multimodal_config = multimodal_config
+        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
 
         self.visual = Glm4vVisionTransformer(
             config.vision_config,
             norm_eps=getattr(config, "rms_norm_eps", 1e-5),
             quant_config=quant_config,
             prefix=maybe_prefix(prefix, "visual"),
+            use_data_parallel=self.use_data_parallel,
         )
 
         if config.model_type == "glm4v":
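Note: the DP vision path is gated on `mm_encoder_tp_mode == "data"`. If I read the config plumbing right, it would be enabled roughly like this; the exact engine-arg spelling and the model id are assumptions, not shown in this diff:

```python
from vllm import LLM

# Assumed invocation: mm_encoder_tp_mode is forwarded to the multimodal
# config; "data" shards vision batches across ranks instead of weights.
llm = LLM(
    model="zai-org/GLM-4.5V",
    tensor_parallel_size=4,
    mm_encoder_tp_mode="data",
)
```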
@@ -1382,8 +1462,14 @@ def _process_image_input(
             image_embeds = image_input["image_embeds"].type(self.visual.dtype)
         else:
             pixel_values = image_input["pixel_values"].type(self.visual.dtype)
-            image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
-
+            if self.use_data_parallel:
+                return run_dp_sharded_mrope_vision_model(self.visual,
+                                                         pixel_values,
+                                                         grid_thw.tolist(),
+                                                         rope_type="rope_3d")
+            else:
+                image_embeds = self.visual(pixel_values,
+                                           grid_thw=grid_thw.tolist())
         merge_size = self.visual.spatial_merge_size
         sizes = grid_thw.prod(-1) // merge_size // merge_size
         return image_embeds.split(sizes.tolist())
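Note: the DP branch returns early because `run_dp_sharded_mrope_vision_model` already yields per-item embeddings, while the TP branch still splits the concatenated output here: each image contributes `t * h * w / merge_size**2` merged embeddings. A worked example of the split arithmetic with hypothetical sizes:

```python
import torch

grid_thw = torch.tensor([[1, 4, 4], [1, 8, 4]])  # (t, h, w) per image
merge_size = 2
sizes = grid_thw.prod(-1) // merge_size // merge_size  # tensor([4, 8])

image_embeds = torch.randn(12, 1536)  # 4 + 8 merged patch embeddings
per_image = image_embeds.split(sizes.tolist())
assert [e.shape[0] for e in per_image] == [4, 8]
```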
@@ -1393,23 +1479,22 @@ def _process_video_input(
         grid_thw = video_input["video_grid_thw"]
         assert grid_thw.ndim == 2
 
-        device = self.visual.device
-        flat_grid_thw = torch.cat([
-            torch.tensor([[1, h, w]] * t, device=device)
-            for t, h, w in grid_thw
-        ])
         if video_input["type"] == "video_embeds":
             video_embeds = video_input["video_embeds"].type(self.visual.dtype)
         else:
             pixel_values_videos = video_input["pixel_values_videos"].type(
                 self.visual.dtype)
-            video_embeds = self.visual(pixel_values_videos,
-                                       grid_thw=flat_grid_thw)
-
+            if self.use_data_parallel:
+                return run_dp_sharded_mrope_vision_model(self.visual,
+                                                         pixel_values_videos,
+                                                         grid_thw.tolist(),
+                                                         rope_type="rope_3d")
+            else:
+                video_embeds = self.visual(pixel_values_videos,
+                                           grid_thw=grid_thw.tolist())
         # Split concatenated embeddings for each video item.
         merge_size = self.visual.spatial_merge_size
         sizes = grid_thw.prod(-1) // merge_size // merge_size
-
         return video_embeds.split(sizes.tolist())
 
     def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: