Commit 2f0bab3

[Model] Support dp on ViT on GLM-4.5V (#23168)

Signed-off-by: David Chen <530634352@qq.com>
Parent: fad73be
File tree: 2 files changed, +145 -59 lines


docs/configuration/optimization.md (1 addition, 0 deletions)
@@ -174,6 +174,7 @@ Regardless, you need to set `mm_encoder_tp_mode="data"` in engine arguments to u
 
 Known supported models:
 
+- GLM-4.5V GLM-4.1V (<gh-pr:23168>)
 - Kimi-VL (<gh-pr:23817>)
 - Llama4 (<gh-pr:18368>)
 - MiniCPM-V-2.5 or above (<gh-pr:23327>, <gh-pr:23948>)
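
For context, the data-parallel encoder mode documented above is enabled purely through engine arguments. Below is a minimal offline-inference sketch; the checkpoint id and TP degree are assumptions for illustration, and `mm_encoder_tp_mode` is assumed to be forwarded to the engine arguments by the `LLM` constructor.

# Minimal sketch (assumed model id and parallel sizes, not taken from this commit).
from vllm import LLM

llm = LLM(
    model="zai-org/GLM-4.5V",      # assumed HF checkpoint id
    tensor_parallel_size=4,        # the language model stays tensor-parallel
    mm_encoder_tp_mode="data",     # replicate the ViT and spread images across ranks
)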

vllm/model_executor/models/glm4_1v.py (144 additions, 59 deletions)
@@ -45,15 +45,20 @@
 from transformers.video_utils import VideoMetadata
 
 from vllm.config import VllmConfig
-from vllm.distributed import parallel_state
+from vllm.distributed import (get_tensor_model_parallel_world_size,
+                              parallel_state)
 from vllm.distributed import utils as dist_utils
 from vllm.logger import init_logger
 from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.layers.layernorm import RMSNorm
+# yapf: disable
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
+                                               MergedReplicatedLinear,
                                                QKVParallelLinear,
+                                               ReplicatedLinear,
                                                RowParallelLinear)
+# yapf: enable
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.module_mapping import MultiModelKeys
@@ -66,6 +71,7 @@
                                         BaseProcessingInfo, PromptReplacement,
                                         PromptUpdate, PromptUpdateDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
+from vllm.multimodal.utils import run_dp_sharded_mrope_vision_model
 from vllm.platforms import _Backend
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.config import uses_mrope
@@ -153,7 +159,7 @@ class Glm4vVideoEmbeddingInputs(TensorSchema):
 
 Glm4vVideoInputs = Union[Glm4vVideoPixelInputs, Glm4vVideoEmbeddingInputs]
 
-# === Vision Encoder === #
+# ==== Vision Encoder ==== #
 
 
 class Glm4vVisionMLP(nn.Module):
@@ -165,19 +171,23 @@ def __init__(
         bias: bool = False,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ):
         super().__init__()
-        self.gate_up_proj = MergedColumnParallelLinear(
-            input_size=in_features,
-            output_sizes=[hidden_features] * 2,
-            bias=bias,
-            quant_config=quant_config,
-            prefix=f"{prefix}.gate_up_proj")
-        self.down_proj = RowParallelLinear(hidden_features,
-                                           in_features,
-                                           bias=bias,
-                                           quant_config=quant_config,
-                                           prefix=f"{prefix}.down_proj")
+        cls_gate_up = (MergedReplicatedLinear
+                       if use_data_parallel else MergedColumnParallelLinear)
+        self.gate_up_proj = cls_gate_up(input_size=in_features,
+                                        output_sizes=[hidden_features] * 2,
+                                        bias=bias,
+                                        quant_config=quant_config,
+                                        prefix=f"{prefix}.gate_up_proj")
+        cls_down = (ReplicatedLinear
+                    if use_data_parallel else RowParallelLinear)
+        self.down_proj = cls_down(hidden_features,
+                                  in_features,
+                                  bias=bias,
+                                  quant_config=quant_config,
+                                  prefix=f"{prefix}.down_proj")
         self.act_fn = SiluAndMul()
 
     def forward(self, x: torch.Tensor):
@@ -218,33 +228,54 @@ def __init__(
         projection_size: int,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
         # Per attention head and per partition values.
-        self.tp_size = parallel_state.get_tensor_model_parallel_world_size()
+        self.tp_size = (1 if use_data_parallel else
+                        get_tensor_model_parallel_world_size())
         self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
         self.hidden_size_per_attention_head = dist_utils.divide(
             projection_size, num_heads)
         self.num_attention_heads_per_partition = dist_utils.divide(
             num_heads, self.tp_size)
 
-        self.qkv = QKVParallelLinear(
-            hidden_size=embed_dim,
-            head_size=self.hidden_size_per_attention_head,
-            total_num_heads=num_heads,
-            total_num_kv_heads=num_heads,
-            bias=False,
-            quant_config=quant_config,
-            # Change qkv prefix to align with GLM-4.5V-FP8 quantization config
-            prefix=f"{prefix}.qkv_proj" if quant_config else f"{prefix}.qkv",
-        )
-        self.proj = RowParallelLinear(
-            input_size=projection_size,
-            output_size=embed_dim,
-            quant_config=quant_config,
-            prefix=f"{prefix}.proj",
-            bias=False,
-        )
+        if use_data_parallel:
+            self.qkv = ReplicatedLinear(
+                input_size=embed_dim,
+                output_size=3 * projection_size,
+                bias=False,
+                quant_config=quant_config,
+                # Change qkv prefix to align with GLM-4.5V-FP8 quantization cfg
+                prefix=f"{prefix}.qkv_proj"
+                if quant_config else f"{prefix}.qkv",
+            )
+            self.proj = ReplicatedLinear(
+                input_size=projection_size,
+                output_size=embed_dim,
+                quant_config=quant_config,
+                prefix=f"{prefix}.proj",
+                bias=False,
+            )
+        else:
+            self.qkv = QKVParallelLinear(
+                hidden_size=embed_dim,
+                head_size=self.hidden_size_per_attention_head,
+                total_num_heads=num_heads,
+                total_num_kv_heads=num_heads,
+                bias=False,
+                quant_config=quant_config,
+                # Change qkv prefix to align with GLM-4.5V-FP8 quantization cfg
+                prefix=f"{prefix}.qkv_proj"
+                if quant_config else f"{prefix}.qkv",
+            )
+            self.proj = RowParallelLinear(
+                input_size=projection_size,
+                output_size=embed_dim,
+                quant_config=quant_config,
+                prefix=f"{prefix}.proj",
+                bias=False,
+            )
 
         # Detect attention implementation.
         self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
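
A quick sanity check of the shape arithmetic behind the hunk above: with `use_data_parallel=True` the TP size is pinned to 1, so each rank keeps every attention head and the replicated qkv projection needs the full `3 * projection_size` output width. The numbers below are hypothetical, not GLM-4.5V's actual configuration.

# Illustrative only; head count and projection size are made-up values.
num_heads = 16
projection_size = 2048
head_dim = projection_size // num_heads              # hidden_size_per_attention_head

# Tensor-parallel path: heads are split across tp_size ranks.
tp_size = 4
heads_per_rank = num_heads // tp_size                # 4 heads on each rank
qkv_width_per_rank = 3 * heads_per_rank * head_dim   # 1536 columns of the sharded qkv

# Data-parallel path: tp_size is forced to 1, so a rank holds all heads and the
# ReplicatedLinear qkv must produce the full 3 * projection_size width.
assert 3 * (num_heads // 1) * head_dim == 3 * projection_size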
@@ -375,6 +406,7 @@ def __init__(
         norm_layer: Optional[Callable[[int], nn.Module]] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
         if norm_layer is None:
@@ -387,13 +419,15 @@ def __init__(
             projection_size=dim,
             quant_config=quant_config,
             prefix=f"{prefix}.attn",
+            use_data_parallel=use_data_parallel,
         )
         self.mlp = Glm4vVisionMLP(
             dim,
             mlp_hidden_dim,
             bias=False,
             quant_config=quant_config,
             prefix=f"{prefix}.mlp",
+            use_data_parallel=use_data_parallel,
         )
 
     def forward(
@@ -456,24 +490,40 @@ def __init__(
         quant_config: Optional[QuantizationConfig] = None,
         bias: bool = False,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
         self.hidden_size = d_model
-        self.proj = ColumnParallelLinear(self.hidden_size,
-                                         self.hidden_size,
-                                         bias=bias,
-                                         gather_output=True,
-                                         quant_config=quant_config,
-                                         prefix=f"{prefix}.proj")
+        if use_data_parallel:
+            self.proj = ReplicatedLinear(
+                input_size=self.hidden_size,
+                output_size=self.hidden_size,
+                bias=bias,
+                quant_config=quant_config,
+                prefix=f"{prefix}.proj",
+            )
+        else:
+            self.proj = ColumnParallelLinear(
+                self.hidden_size,
+                self.hidden_size,
+                bias=bias,
+                gather_output=True,
+                quant_config=quant_config,
+                prefix=f"{prefix}.proj",
+            )
         self.post_projection_norm = nn.LayerNorm(self.hidden_size)
-        self.gate_up_proj = MergedColumnParallelLinear(
+        cls_gate_up = (MergedReplicatedLinear
+                       if use_data_parallel else MergedColumnParallelLinear)
+        self.gate_up_proj = cls_gate_up(
             input_size=self.hidden_size,
             output_sizes=[context_dim] * 2,
             bias=bias,
             quant_config=quant_config,
             prefix=f"{prefix}.gate_up_proj",
         )
-        self.down_proj = RowParallelLinear(
+        cls_down = (ReplicatedLinear
+                    if use_data_parallel else RowParallelLinear)
+        self.down_proj = cls_down(
             context_dim,
             self.hidden_size,
             bias=bias,
@@ -548,14 +598,33 @@ def forward(self, embeddings, lengths, image_shapes, h_coords,
                                    dtype=torch.float32))
 
         # Calculate target dimensions for each patch
-        target_h = torch.cat([
-            image_shapes[i, 1].repeat(lengths[i])
-            for i in range(len(lengths))
-        ]).to(device=device, dtype=torch.float32)
-        target_w = torch.cat([
-            image_shapes[i, 2].repeat(lengths[i])
-            for i in range(len(lengths))
-        ]).to(device=device, dtype=torch.float32)
+        # Add bounds checking for data parallel mode
+        if len(lengths) > image_shapes.shape[0]:
+            # In data parallel mode, some GPUs might not have all
+            # image shapes
+            # Use available image shapes, cycling if necessary
+            target_h_list = []
+            target_w_list = []
+            for i in range(len(lengths)):
+                # Cycle through available shapes
+                shape_idx = i % image_shapes.shape[0]
+                target_h_list.append(image_shapes[shape_idx,
+                                                  1].repeat(lengths[i]))
+                target_w_list.append(image_shapes[shape_idx,
                                                  2].repeat(lengths[i]))
+            target_h = torch.cat(target_h_list).to(device=device,
+                                                   dtype=torch.float32)
+            target_w = torch.cat(target_w_list).to(device=device,
+                                                   dtype=torch.float32)
+        else:
+            target_h = torch.cat([
+                image_shapes[i, 1].repeat(lengths[i])
+                for i in range(len(lengths))
+            ]).to(device=device, dtype=torch.float32)
+            target_w = torch.cat([
+                image_shapes[i, 2].repeat(lengths[i])
+                for i in range(len(lengths))
+            ]).to(device=device, dtype=torch.float32)
 
         # Normalize coordinates to [-1, 1] range for grid_sample
         h_coords = h_coords.to(device=device, dtype=torch.float32)
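
The fallback above only changes how rows of `image_shapes` are indexed when a DP rank receives more `lengths` entries than shapes. A tiny standalone illustration of the cycling index, with invented sizes:

# Standalone illustration of the shape_idx cycling above; all values are invented.
import torch

lengths = [4, 6, 3]                            # three patch groups on this rank
image_shapes = torch.tensor([[1, 32, 48],
                             [1, 16, 24]])     # only two shapes available

for i in range(len(lengths)):
    shape_idx = i % image_shapes.shape[0]      # cycles 0, 1, 0
    target_h = image_shapes[shape_idx, 1].repeat(lengths[i])
    print(i, shape_idx, tuple(target_h.shape))  # (4,), then (6,), then (3,)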
@@ -629,6 +698,7 @@ def __init__(
         norm_eps: float = 1e-6,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
 
@@ -638,6 +708,7 @@ def __init__(
         depth = vision_config.depth
         self.hidden_size = vision_config.hidden_size
         self.num_heads = vision_config.num_heads
+        self.use_data_parallel = use_data_parallel
 
         self.patch_size = vision_config.patch_size
         self.spatial_merge_size = vision_config.spatial_merge_size
@@ -661,6 +732,7 @@ def __init__(
                 norm_layer=norm_layer,
                 quant_config=quant_config,
                 prefix=f"{prefix}.blocks.{layer_idx}",
+                use_data_parallel=self.use_data_parallel,
             ) for layer_idx in range(depth)
         ])
         self.merger = Glm4vPatchMerger(
@@ -669,6 +741,7 @@
             quant_config=quant_config,
             bias=False,
             prefix=f"{prefix}.merger",
+            use_data_parallel=self.use_data_parallel,
         )
         self.embeddings = Glm4vVisionEmbeddings(vision_config)
 
@@ -731,8 +804,11 @@ def compute_attn_mask_seqlen(
     def forward(
         self,
         x: torch.Tensor,
-        grid_thw: torch.Tensor,
+        grid_thw: list[list[int]],
     ) -> torch.Tensor:
+        # Convert grid_thw to tensor (always expecting list format now)
+        grid_thw = torch.tensor(grid_thw, device=x.device, dtype=torch.long)
+
         # patchify
         x = x.to(device=self.device, dtype=self.dtype)
         x = self.patch_embed(x)
@@ -1250,6 +1326,8 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
            "model.visual.": "visual.",
        })
 
+    supports_encoder_tp_data = True
+
     @classmethod
     def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
         if modality.startswith("image"):
@@ -1267,12 +1345,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
 
         self.config = config
         self.multimodal_config = multimodal_config
+        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
 
         self.visual = Glm4vVisionTransformer(
             config.vision_config,
             norm_eps=getattr(config, "rms_norm_eps", 1e-5),
             quant_config=quant_config,
             prefix=maybe_prefix(prefix, "visual"),
+            use_data_parallel=self.use_data_parallel,
         )
 
         if config.model_type == "glm4v":
@@ -1382,8 +1462,14 @@ def _process_image_input(
             image_embeds = image_input["image_embeds"].type(self.visual.dtype)
         else:
             pixel_values = image_input["pixel_values"].type(self.visual.dtype)
-            image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
-
+            if self.use_data_parallel:
+                return run_dp_sharded_mrope_vision_model(self.visual,
+                                                         pixel_values,
+                                                         grid_thw.tolist(),
+                                                         rope_type="rope_3d")
+            else:
+                image_embeds = self.visual(pixel_values,
+                                           grid_thw=grid_thw.tolist())
         merge_size = self.visual.spatial_merge_size
         sizes = grid_thw.prod(-1) // merge_size // merge_size
         return image_embeds.split(sizes.tolist())
@@ -1393,23 +1479,22 @@ def _process_video_input(
         grid_thw = video_input["video_grid_thw"]
         assert grid_thw.ndim == 2
 
-        device = self.visual.device
-        flat_grid_thw = torch.cat([
-            torch.tensor([[1, h, w]] * t, device=device)
-            for t, h, w in grid_thw
-        ])
         if video_input["type"] == "video_embeds":
             video_embeds = video_input["video_embeds"].type(self.visual.dtype)
         else:
             pixel_values_videos = video_input["pixel_values_videos"].type(
                 self.visual.dtype)
-            video_embeds = self.visual(pixel_values_videos,
-                                       grid_thw=flat_grid_thw)
-
+            if self.use_data_parallel:
+                return run_dp_sharded_mrope_vision_model(self.visual,
+                                                         pixel_values_videos,
+                                                         grid_thw.tolist(),
+                                                         rope_type="rope_3d")
+            else:
+                video_embeds = self.visual(pixel_values_videos,
+                                           grid_thw=grid_thw.tolist())
         # Split concatenated embeddings for each video item.
         merge_size = self.visual.spatial_merge_size
         sizes = grid_thw.prod(-1) // merge_size // merge_size
-
         return video_embeds.split(sizes.tolist())
 
     def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
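
In both `_process_image_input` and `_process_video_input`, the data-parallel path hands the replicated `self.visual` and the per-item `grid_thw` list to `run_dp_sharded_mrope_vision_model`, whose internals are not part of this diff. The sketch below is only a conceptual illustration of one way such a helper could balance items across ranks by patch count; it is not vLLM's implementation, and the function name and greedy policy are assumptions.

# Conceptual sketch: assign whole images/videos to DP ranks, balancing patch counts.
def assign_items_to_ranks(grid_thw: list[list[int]],
                          dp_size: int) -> list[list[int]]:
    """Return, for each rank, the indices of the items it should encode."""
    patch_counts = [t * h * w for t, h, w in grid_thw]
    loads = [0] * dp_size
    assignment: list[list[int]] = [[] for _ in range(dp_size)]
    # Place the largest items first so the greedy balance stays tight.
    for idx in sorted(range(len(grid_thw)),
                      key=lambda i: patch_counts[i],
                      reverse=True):
        rank = loads.index(min(loads))
        assignment[rank].append(idx)
        loads[rank] += patch_counts[idx]
    return assignment


# Example: four images of different sizes spread over two ranks.
print(assign_items_to_ranks([[1, 32, 48], [1, 16, 16], [1, 64, 64], [1, 8, 8]],
                            dp_size=2))
# -> [[2], [0, 1, 3]]: rank 0 takes the largest image, rank 1 takes the rest.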
