 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
+                                               ReplicatedLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
                                         BaseProcessingInfo, PromptReplacement,
                                         PromptUpdate, PromptUpdateDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
+from vllm.multimodal.utils import run_dp_sharded_vision_model
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs import Step3VisionEncoderConfig
 from vllm.transformers_utils.tokenizer import AnyTokenizer
@@ -650,7 +652,8 @@ class Step3VisionAttention(nn.Module):
     def __init__(self,
                  config,
                  quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
+                 prefix: str = "",
+                 use_data_parallel: bool = False):
         super().__init__()
         self.config = config
         self.embed_dim = config.hidden_size
@@ -659,20 +662,42 @@ def __init__(self,

         self.scale = self.head_dim**-0.5

-        tp_size = get_tensor_model_parallel_world_size()
+        tp_size = (1 if use_data_parallel else
+                   get_tensor_model_parallel_world_size())
         assert self.total_num_heads % tp_size == 0
         self.num_heads = self.total_num_heads // tp_size
-        self.qkv_proj = QKVParallelLinear(self.embed_dim,
-                                          self.head_dim,
-                                          self.total_num_heads,
-                                          bias=True,
-                                          quant_config=quant_config,
-                                          prefix=prefix)
-        self.out_proj = RowParallelLinear(self.embed_dim,
-                                          self.embed_dim,
-                                          bias=True,
-                                          quant_config=quant_config,
-                                          prefix=prefix)
+
+        self.q_size = self.num_heads * self.head_dim
+
+        if use_data_parallel:
+            self.qkv_proj = ReplicatedLinear(
+                self.embed_dim,
+                3 * self.q_size,
+                bias=True,
+                quant_config=quant_config,
+                prefix=prefix,
+            )
+            self.out_proj = ReplicatedLinear(
+                self.total_num_heads * self.head_dim,
+                self.embed_dim,
+                bias=True,
+                quant_config=quant_config,
+                prefix=prefix,
+            )
+        else:
+            self.qkv_proj = QKVParallelLinear(
+                self.embed_dim,
+                self.head_dim,
+                self.total_num_heads,
+                bias=True,
+                quant_config=quant_config,
+                prefix=prefix,
+            )
+            self.out_proj = RowParallelLinear(self.embed_dim,
+                                              self.embed_dim,
+                                              bias=True,
+                                              quant_config=quant_config,
+                                              prefix=prefix)

     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
         return tensor.view(bsz, seq_len, self.num_heads,
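A note on the sizing logic in the hunk above (my reading of the diff, not text from the PR): with use_data_parallel=True the effective TP world size is pinned to 1, so every rank keeps all attention heads and the fused QKV ReplicatedLinear is simply three full-width head blocks, while the TP branch lets QKVParallelLinear derive the per-rank width. A minimal arithmetic sketch with made-up example dimensions (the real values come from Step3VisionEncoderConfig):

# Illustrative sizing only; hidden_size and total_num_heads are assumed values.
hidden_size = 1792
total_num_heads = 16
head_dim = hidden_size // total_num_heads      # 112

# Data-parallel vision encoder: nothing is sharded on this rank.
tp_size = 1
num_heads = total_num_heads // tp_size         # all 16 heads stay local
q_size = num_heads * head_dim                  # 1792
qkv_out = 3 * q_size                           # 5376 -> ReplicatedLinear output dim

# Tensor-parallel encoder (e.g. tp_size = 4): 4 heads per rank, and
# QKVParallelLinear computes the per-rank output width internally.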
@@ -712,20 +737,25 @@ class Step3VisionMLP(nn.Module):
     def __init__(self,
                  config,
                  quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
+                 prefix: str = "",
+                 use_data_parallel: bool = False):
         super().__init__()
         self.config = config
         self.activation_fn = get_act_fn(config.hidden_act)
-        self.fc1 = ColumnParallelLinear(config.hidden_size,
-                                        config.intermediate_size,
-                                        bias=True,
-                                        quant_config=quant_config,
-                                        prefix=prefix)
-        self.fc2 = RowParallelLinear(config.intermediate_size,
-                                     config.hidden_size,
-                                     bias=True,
-                                     quant_config=quant_config,
-                                     prefix=prefix)
+        cls_fc1 = (ReplicatedLinear
+                   if use_data_parallel else ColumnParallelLinear)
+        self.fc1 = cls_fc1(config.hidden_size,
+                           config.intermediate_size,
+                           bias=True,
+                           quant_config=quant_config,
+                           prefix=prefix)
+        cls_fc2 = (ReplicatedLinear
+                   if use_data_parallel else RowParallelLinear)
+        self.fc2 = cls_fc2(config.intermediate_size,
+                           config.hidden_size,
+                           bias=True,
+                           quant_config=quant_config,
+                           prefix=prefix)

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states, _ = self.fc1(hidden_states)
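Why the MLP forward in the trailing context lines needs no change is worth a brief aside (my observation, not something stated in the diff): vLLM's linear layer classes, ReplicatedLinear included, return an (output, output_bias) tuple, so the existing hidden_states, _ = self.fc1(hidden_states) unpacking works for whichever class was selected. A toy, self-contained sketch of that calling convention; TupleLinear is a hypothetical stand-in, not a vLLM class:

import torch
import torch.nn as nn

class TupleLinear(nn.Linear):
    """Stand-in mimicking the (output, output_bias) return convention."""
    def forward(self, x: torch.Tensor):
        return super().forward(x), None

fc1 = TupleLinear(64, 256)
fc2 = TupleLinear(256, 64)
hidden, _ = fc1(torch.randn(2, 64))   # same unpacking Step3VisionMLP uses
hidden, _ = fc2(torch.relu(hidden))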
@@ -739,15 +769,22 @@ class Step3VisionEncoderLayer(nn.Module):
     def __init__(self,
                  config: Step3VisionEncoderConfig,
                  quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
+                 prefix: str = "",
+                 use_data_parallel: bool = False):
         super().__init__()
+        self.use_data_parallel = use_data_parallel
         self.embed_dim = config.hidden_size
-        self.self_attn = Step3VisionAttention(config,
-                                              quant_config,
-                                              prefix=f"{prefix}.self_attn")
+        self.self_attn = Step3VisionAttention(
+            config,
+            quant_config,
+            prefix=f"{prefix}.self_attn",
+            use_data_parallel=self.use_data_parallel)
         self.layer_norm1 = nn.LayerNorm(self.embed_dim,
                                         eps=config.layer_norm_eps)
-        self.mlp = Step3VisionMLP(config, quant_config, prefix=f"{prefix}.mlp")
+        self.mlp = Step3VisionMLP(config,
+                                  quant_config,
+                                  prefix=f"{prefix}.mlp",
+                                  use_data_parallel=self.use_data_parallel)
         self.layer_norm2 = nn.LayerNorm(self.embed_dim,
                                         eps=config.layer_norm_eps)

@@ -767,13 +804,16 @@ class Step3VisionEncoder(nn.Module):
     def __init__(self,
                  config: Step3VisionEncoderConfig,
                  quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
+                 prefix: str = "",
+                 use_data_parallel: bool = False):
         super().__init__()
         self.config = config
+        self.use_data_parallel = use_data_parallel
         self.layers = nn.ModuleList([
             Step3VisionEncoderLayer(config,
                                     quant_config,
-                                    prefix=f"{prefix}.layers.{i}")
+                                    prefix=f"{prefix}.layers.{i}",
+                                    use_data_parallel=self.use_data_parallel)
             for i in range(config.num_hidden_layers)
         ])

@@ -792,21 +832,29 @@ class Step3VisionTransformer(nn.Module):
     def __init__(self,
                  config: Step3VisionEncoderConfig,
                  quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
+                 prefix: str = "",
+                 use_data_parallel: bool = False):
         super().__init__()
         self.config = config
+        self.use_data_parallel = use_data_parallel
         self.image_size = config.image_size
         self.embeddings = Step3VisionEmbeddings(config)
-        self.transformer = Step3VisionEncoder(config,
-                                              quant_config,
-                                              prefix=f"{prefix}.transformer")
+        self.transformer = Step3VisionEncoder(
+            config,
+            quant_config,
+            prefix=f"{prefix}.transformer",
+            use_data_parallel=self.use_data_parallel)

     def forward(
         self,
         pixel_values: torch.Tensor,
     ):
         hidden_states = self.embeddings(pixel_values)
-        hidden_states = self.transformer(inputs_embeds=hidden_states)
+        if self.use_data_parallel:
+            hidden_states = run_dp_sharded_vision_model(
+                hidden_states, self.transformer)
+        else:
+            hidden_states = self.transformer(inputs_embeds=hidden_states)
         return hidden_states

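For context on run_dp_sharded_vision_model as called above: to my understanding it splits the image batch across the ranks that would otherwise tensor-parallelize the encoder, runs the replicated encoder on each local shard, and gathers the per-rank outputs back into a full batch. The sketch below is a conceptual, single-process stand-in for that idea, not vLLM's implementation; dp_sharded_forward and the plain concatenation in place of an all-gather are assumptions for illustration:

import torch
import torch.nn as nn

def dp_sharded_forward(batch: torch.Tensor, encoders: list) -> torch.Tensor:
    # Each simulated "rank" holds a full copy of the encoder and processes
    # only its slice of the batch; torch.cat stands in for the all-gather
    # that would reassemble the batch across real ranks.
    shards = batch.chunk(len(encoders), dim=0)
    outs = [enc(shard) for enc, shard in zip(encoders, shards)]
    return torch.cat(outs, dim=0)

encoder = nn.Linear(8, 8)                        # toy "vision encoder"
full = dp_sharded_forward(torch.randn(4, 8), [encoder, encoder])
assert full.shape == (4, 8)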
@@ -836,13 +884,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:

         self.config = config
         self.multimodal_config = multimodal_config
+        self.use_data_parallel = (vllm_config.parallel_config.
+                                  enable_multimodal_encoder_data_parallel)

         if multimodal_config.get_limit_per_prompt("image"):
-            self.vision_model = Step3VisionTransformer(config.vision_config,
-                                                        None,
-                                                        prefix=maybe_prefix(
-                                                            prefix,
-                                                            "vision_model"))
+            self.vision_model = Step3VisionTransformer(
+                config.vision_config,
+                None,
+                prefix=maybe_prefix(prefix, "vision_model"),
+                use_data_parallel=self.use_data_parallel)
             self.vit_downsampler = nn.Conv2d(
                 config.vision_config.hidden_size,
                 config.vision_config.output_hidden_size,
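On enabling the new path: the model reads vllm_config.parallel_config.enable_multimodal_encoder_data_parallel, as shown above. Assuming the engine arguments expose that field under the same name (I have not verified the exact CLI or constructor spelling), usage would look roughly like this hypothetical sketch; the model id is a placeholder:

from vllm import LLM

# Hypothetical sketch -- the kwarg name mirrors the ParallelConfig field in
# this diff; with it set, the Step3 ViT keeps replicated weights on every rank
# and shards image batches instead of sharding the encoder weights.
llm = LLM(
    model="stepfun-ai/step3",                       # placeholder model id
    tensor_parallel_size=4,
    enable_multimodal_encoder_data_parallel=True,   # assumed engine-arg name
)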