docs/configuration/optimization.md (1 change: 1 addition, 0 deletions)

@@ -174,6 +174,7 @@ Regardless, you need to set `mm_encoder_tp_mode="data"` in engine arguments to u

Known supported models:

- GLM-4.5V, GLM-4.1V (<gh-pr:23168>)
- Kimi-VL (<gh-pr:23817>)
- Llama4 (<gh-pr:18368>)
- MiniCPM-V-2.5 or above (<gh-pr:23327>, <gh-pr:23948>)
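The flag referenced in the hunk above is an engine argument. A minimal usage sketch, assuming a vLLM build that exposes `mm_encoder_tp_mode` on `LLM` (the model id and TP degree are illustrative, not taken from this PR):

```python
from vllm import LLM

# Hedged sketch: run the vision encoder data-parallel while the language
# model stays tensor-parallel across the same ranks.
llm = LLM(
    model="zai-org/GLM-4.5V",   # illustrative model id
    tensor_parallel_size=4,     # illustrative TP degree
    mm_encoder_tp_mode="data",  # the setting this docs change refers to
)
```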
vllm/model_executor/models/glm4_1v.py (203 changes: 144 additions, 59 deletions)

@@ -45,15 +45,20 @@
from transformers.video_utils import VideoMetadata

from vllm.config import VllmConfig
from vllm.distributed import parallel_state
from vllm.distributed import (get_tensor_model_parallel_world_size,
parallel_state)
from vllm.distributed import utils as dist_utils
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.layernorm import RMSNorm
# yapf: disable
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
MergedReplicatedLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
# yapf: enable
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
@@ -66,6 +71,7 @@
BaseProcessingInfo, PromptReplacement,
PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.multimodal.utils import run_dp_sharded_mrope_vision_model
from vllm.platforms import _Backend
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.config import uses_mrope
@@ -153,7 +159,7 @@ class Glm4vVideoEmbeddingInputs(TensorSchema):

Glm4vVideoInputs = Union[Glm4vVideoPixelInputs, Glm4vVideoEmbeddingInputs]

# === Vision Encoder === #
# ==== Vision Encoder ==== #


class Glm4vVisionMLP(nn.Module):
@@ -165,19 +171,23 @@ def __init__(
bias: bool = False,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_data_parallel: bool = False,
):
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
input_size=in_features,
output_sizes=[hidden_features] * 2,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj")
self.down_proj = RowParallelLinear(hidden_features,
in_features,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.down_proj")
cls_gate_up = (MergedReplicatedLinear
if use_data_parallel else MergedColumnParallelLinear)
self.gate_up_proj = cls_gate_up(input_size=in_features,
output_sizes=[hidden_features] * 2,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj")
cls_down = (ReplicatedLinear
if use_data_parallel else RowParallelLinear)
self.down_proj = cls_down(hidden_features,
in_features,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.down_proj")
self.act_fn = SiluAndMul()

def forward(self, x: torch.Tensor):
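This constructor shows the core pattern of the PR: when `use_data_parallel` is set, TP-sharded linear layers are swapped for replicated ones, because each rank runs the full vision tower on its own shard of the inputs. A toy sketch of what this means for per-rank weight shapes (stand-in sizes, not vLLM code):

```python
import torch

# With TP sharding, each of tp_size ranks holds 1/tp_size of the output
# features; with replication (the DP path), every rank holds the full
# weight and instead receives a subset of the images.
in_features, hidden_features, tp_size = 8, 32, 4

tp_weight = torch.empty(hidden_features // tp_size, in_features)
dp_weight = torch.empty(hidden_features, in_features)

print(tp_weight.shape)  # torch.Size([8, 8])
print(dp_weight.shape)  # torch.Size([32, 8])
```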
@@ -218,33 +228,54 @@ def __init__(
projection_size: int,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_data_parallel: bool = False,
) -> None:
super().__init__()
# Per attention head and per partition values.
self.tp_size = parallel_state.get_tensor_model_parallel_world_size()
self.tp_size = (1 if use_data_parallel else
get_tensor_model_parallel_world_size())
self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
self.hidden_size_per_attention_head = dist_utils.divide(
projection_size, num_heads)
self.num_attention_heads_per_partition = dist_utils.divide(
num_heads, self.tp_size)

self.qkv = QKVParallelLinear(
hidden_size=embed_dim,
head_size=self.hidden_size_per_attention_head,
total_num_heads=num_heads,
total_num_kv_heads=num_heads,
bias=False,
quant_config=quant_config,
# Change qkv prefix to align with GLM-4.5V-FP8 quantization config
prefix=f"{prefix}.qkv_proj" if quant_config else f"{prefix}.qkv",
)
self.proj = RowParallelLinear(
input_size=projection_size,
output_size=embed_dim,
quant_config=quant_config,
prefix=f"{prefix}.proj",
bias=False,
)
if use_data_parallel:
self.qkv = ReplicatedLinear(
input_size=embed_dim,
output_size=3 * projection_size,
bias=False,
quant_config=quant_config,
# Change qkv prefix to align with GLM-4.5V-FP8 quantization cfg
prefix=f"{prefix}.qkv_proj"
if quant_config else f"{prefix}.qkv",
)
self.proj = ReplicatedLinear(
input_size=projection_size,
output_size=embed_dim,
quant_config=quant_config,
prefix=f"{prefix}.proj",
bias=False,
)
else:
self.qkv = QKVParallelLinear(
hidden_size=embed_dim,
head_size=self.hidden_size_per_attention_head,
total_num_heads=num_heads,
total_num_kv_heads=num_heads,
bias=False,
quant_config=quant_config,
# Change qkv prefix to align with GLM-4.5V-FP8 quantization cfg
prefix=f"{prefix}.qkv_proj"
if quant_config else f"{prefix}.qkv",
)
self.proj = RowParallelLinear(
input_size=projection_size,
output_size=embed_dim,
quant_config=quant_config,
prefix=f"{prefix}.proj",
bias=False,
)

# Detect attention implementation.
self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
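Forcing `self.tp_size` to 1 under data parallelism means the partition math above keeps every attention head local to the rank. A toy check with stand-in numbers:

```python
# num_heads and tp_world_size are made-up values for illustration.
num_heads, tp_world_size = 16, 4

def heads_per_rank(use_data_parallel: bool) -> int:
    tp_size = 1 if use_data_parallel else tp_world_size
    assert num_heads % tp_size == 0  # mirrors what dist_utils.divide enforces
    return num_heads // tp_size

print(heads_per_rank(True))   # 16 -> full attention on each DP rank
print(heads_per_rank(False))  # 4  -> heads sharded across TP ranks
```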
@@ -375,6 +406,7 @@ def __init__(
norm_layer: Optional[Callable[[int], nn.Module]] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_data_parallel: bool = False,
) -> None:
super().__init__()
if norm_layer is None:
@@ -387,13 +419,15 @@ def __init__(
projection_size=dim,
quant_config=quant_config,
prefix=f"{prefix}.attn",
use_data_parallel=use_data_parallel,
)
self.mlp = Glm4vVisionMLP(
dim,
mlp_hidden_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
use_data_parallel=use_data_parallel,
)

def forward(
@@ -456,24 +490,40 @@ def __init__(
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
prefix: str = "",
use_data_parallel: bool = False,
) -> None:
super().__init__()
self.hidden_size = d_model
self.proj = ColumnParallelLinear(self.hidden_size,
self.hidden_size,
bias=bias,
gather_output=True,
quant_config=quant_config,
prefix=f"{prefix}.proj")
if use_data_parallel:
self.proj = ReplicatedLinear(
input_size=self.hidden_size,
output_size=self.hidden_size,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.proj",
)
else:
self.proj = ColumnParallelLinear(
self.hidden_size,
self.hidden_size,
bias=bias,
gather_output=True,
quant_config=quant_config,
prefix=f"{prefix}.proj",
)
self.post_projection_norm = nn.LayerNorm(self.hidden_size)
self.gate_up_proj = MergedColumnParallelLinear(
cls_gate_up = (MergedReplicatedLinear
if use_data_parallel else MergedColumnParallelLinear)
self.gate_up_proj = cls_gate_up(
input_size=self.hidden_size,
output_sizes=[context_dim] * 2,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
)
self.down_proj = RowParallelLinear(
cls_down = (ReplicatedLinear
if use_data_parallel else RowParallelLinear)
self.down_proj = cls_down(
context_dim,
self.hidden_size,
bias=bias,
@@ -548,14 +598,33 @@ def forward(self, embeddings, lengths, image_shapes, h_coords,
dtype=torch.float32))

# Calculate target dimensions for each patch
target_h = torch.cat([
image_shapes[i, 1].repeat(lengths[i])
for i in range(len(lengths))
]).to(device=device, dtype=torch.float32)
target_w = torch.cat([
image_shapes[i, 2].repeat(lengths[i])
for i in range(len(lengths))
]).to(device=device, dtype=torch.float32)
# Add bounds checking for data parallel mode
if len(lengths) > image_shapes.shape[0]:
# In data parallel mode, some GPUs might not have all
# image shapes
# Use available image shapes, cycling if necessary
target_h_list = []
target_w_list = []
for i in range(len(lengths)):
# Cycle through available shapes
shape_idx = i % image_shapes.shape[0]
target_h_list.append(image_shapes[shape_idx,
1].repeat(lengths[i]))
target_w_list.append(image_shapes[shape_idx,
2].repeat(lengths[i]))
target_h = torch.cat(target_h_list).to(device=device,
dtype=torch.float32)
target_w = torch.cat(target_w_list).to(device=device,
dtype=torch.float32)
else:
target_h = torch.cat([
image_shapes[i, 1].repeat(lengths[i])
for i in range(len(lengths))
]).to(device=device, dtype=torch.float32)
target_w = torch.cat([
image_shapes[i, 2].repeat(lengths[i])
for i in range(len(lengths))
]).to(device=device, dtype=torch.float32)

# Normalize coordinates to [-1, 1] range for grid_sample
h_coords = h_coords.to(device=device, dtype=torch.float32)
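A quick walk-through of the cycling fallback added above, with made-up shapes, showing how `shape_idx = i % image_shapes.shape[0]` reuses the rows that are available on the rank:

```python
import torch

# Three per-image patch counts, but only two rows of image shapes.
lengths = [4, 2, 3]
image_shapes = torch.tensor([[1, 14, 16],
                             [1, 10, 12]])  # (num_images, [t, h, w])

target_h = torch.cat([
    image_shapes[i % image_shapes.shape[0], 1].repeat(lengths[i])
    for i in range(len(lengths))
])
print(target_h.tolist())  # [14, 14, 14, 14, 10, 10, 14, 14, 14]
```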
@@ -629,6 +698,7 @@ def __init__(
norm_eps: float = 1e-6,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_data_parallel: bool = False,
) -> None:
super().__init__()

@@ -638,6 +708,7 @@ def __init__(
depth = vision_config.depth
self.hidden_size = vision_config.hidden_size
self.num_heads = vision_config.num_heads
self.use_data_parallel = use_data_parallel

self.patch_size = vision_config.patch_size
self.spatial_merge_size = vision_config.spatial_merge_size
@@ -661,6 +732,7 @@ def __init__(
norm_layer=norm_layer,
quant_config=quant_config,
prefix=f"{prefix}.blocks.{layer_idx}",
use_data_parallel=self.use_data_parallel,
) for layer_idx in range(depth)
])
self.merger = Glm4vPatchMerger(
@@ -669,6 +741,7 @@ def __init__(
quant_config=quant_config,
bias=False,
prefix=f"{prefix}.merger",
use_data_parallel=self.use_data_parallel,
)
self.embeddings = Glm4vVisionEmbeddings(vision_config)

@@ -731,8 +804,11 @@ def compute_attn_mask_seqlen(
def forward(
self,
x: torch.Tensor,
grid_thw: torch.Tensor,
grid_thw: list[list[int]],
) -> torch.Tensor:
# Convert grid_thw to tensor (always expecting list format now)
grid_thw = torch.tensor(grid_thw, device=x.device, dtype=torch.long)

# patchify
x = x.to(device=self.device, dtype=self.dtype)
x = self.patch_embed(x)
@@ -1250,6 +1326,8 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
"model.visual.": "visual.",
})

supports_encoder_tp_data = True
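# Assumed contract: the engine checks this flag before honoring
# mm_encoder_tp_mode="data" for this model.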

@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
if modality.startswith("image"):
@@ -1267,12 +1345,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

self.config = config
self.multimodal_config = multimodal_config
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"

self.visual = Glm4vVisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-5),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "visual"),
use_data_parallel=self.use_data_parallel,
)

if config.model_type == "glm4v":
@@ -1382,8 +1462,14 @@ def _process_image_input(
image_embeds = image_input["image_embeds"].type(self.visual.dtype)
else:
pixel_values = image_input["pixel_values"].type(self.visual.dtype)
image_embeds = self.visual(pixel_values, grid_thw=grid_thw)

if self.use_data_parallel:
return run_dp_sharded_mrope_vision_model(self.visual,
pixel_values,
grid_thw.tolist(),
rope_type="rope_3d")
else:
image_embeds = self.visual(pixel_values,
grid_thw=grid_thw.tolist())
merge_size = self.visual.spatial_merge_size
sizes = grid_thw.prod(-1) // merge_size // merge_size
return image_embeds.split(sizes.tolist())
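For the non-DP path, `sizes` determines how many merged embeddings each image receives when splitting the concatenated output. A worked example with stand-in grids:

```python
import torch

# One 1x28x28 grid and one 1x14x20 grid, spatial merge size 2:
# 784 // 4 = 196 and 280 // 4 = 70 embeddings respectively.
grid_thw = torch.tensor([[1, 28, 28], [1, 14, 20]])
merge_size = 2
sizes = grid_thw.prod(-1) // merge_size // merge_size
print(sizes.tolist())  # [196, 70]
```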
@@ -1393,23 +1479,22 @@ def _process_video_input(
grid_thw = video_input["video_grid_thw"]
assert grid_thw.ndim == 2

device = self.visual.device
flat_grid_thw = torch.cat([
torch.tensor([[1, h, w]] * t, device=device)
for t, h, w in grid_thw
])
if video_input["type"] == "video_embeds":
video_embeds = video_input["video_embeds"].type(self.visual.dtype)
else:
pixel_values_videos = video_input["pixel_values_videos"].type(
self.visual.dtype)
video_embeds = self.visual(pixel_values_videos,
grid_thw=flat_grid_thw)

if self.use_data_parallel:
return run_dp_sharded_mrope_vision_model(self.visual,
pixel_values_videos,
grid_thw.tolist(),
rope_type="rope_3d")
else:
video_embeds = self.visual(pixel_values_videos,
grid_thw=grid_thw.tolist())
# Split concatenated embeddings for each video item.
merge_size = self.visual.spatial_merge_size
sizes = grid_thw.prod(-1) // merge_size // merge_size

return video_embeds.split(sizes.tolist())

def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
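Both the image and video paths now route through `run_dp_sharded_mrope_vision_model`, which distributes whole items across TP ranks, runs the full vision tower on each shard, and gathers the results in order. A toy illustration of the sharding idea only; the real helper's load-balancing policy may differ:

```python
# Round-robin assignment of items to ranks (illustrative policy).
items = ["img0", "img1", "img2", "img3", "img4"]
tp_size = 2
shards = [items[r::tp_size] for r in range(tp_size)]
print(shards)  # [['img0', 'img2', 'img4'], ['img1', 'img3']]
```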