78 changes: 59 additions & 19 deletions vllm/model_executor/models/qwen2_vl.py
@@ -66,6 +66,7 @@
BaseProcessingInfo, PromptReplacement,
PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.multimodal.utils import run_dp_sharded_mrope_vision_model
from vllm.platforms import _Backend, current_platform
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.config import uses_mrope
@@ -217,17 +218,20 @@ def __init__(
act_layer: type[nn.Module] = QuickGELU,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_data_parallel: bool = False,
):
super().__init__()
self.fc1 = ColumnParallelLinear(in_features,
hidden_features,
quant_config=quant_config,
prefix=f"{prefix}.fc1")
prefix=f"{prefix}.fc1",
disable_tp=use_data_parallel)
self.act = act_layer()
self.fc2 = RowParallelLinear(hidden_features,
in_features,
quant_config=quant_config,
prefix=f"{prefix}.fc2")
prefix=f"{prefix}.fc2",
disable_tp=use_data_parallel)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x_parallel, _ = self.fc1(x)
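
Note: a hedged sketch of what `disable_tp=use_data_parallel` buys here (not vLLM's internals). With encoder data parallelism, `ColumnParallelLinear`/`RowParallelLinear` are assumed to keep their full weights on every rank instead of a tensor-parallel shard, so fc1/fc2 behave like plain linear layers. The helper and shapes below are illustrative only.

```python
# Illustrative only: expected fc1/fc2 weight shapes with and without
# tensor-parallel sharding of the vision MLP. ColumnParallelLinear shards
# the output dim, RowParallelLinear shards the input dim; disable_tp is
# assumed to skip that sharding.
def mlp_weight_shapes(in_features: int, hidden_features: int,
                      tp_size: int, use_data_parallel: bool):
    shard = 1 if use_data_parallel else tp_size
    fc1_shape = (hidden_features // shard, in_features)   # column-parallel
    fc2_shape = (in_features, hidden_features // shard)   # row-parallel
    return fc1_shape, fc2_shape

print(mlp_weight_shapes(1280, 5120, tp_size=4, use_data_parallel=False))
# ((1280, 1280), (1280, 1280))
print(mlp_weight_shapes(1280, 5120, tp_size=4, use_data_parallel=True))
# ((5120, 1280), (1280, 5120))
```
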
@@ -293,25 +297,28 @@ def __init__(
projection_size: int,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_data_parallel: bool = False,
) -> None:
super().__init__()
# Per attention head and per partition values.
world_size = parallel_state.get_tensor_model_parallel_world_size()
self.tp_size = world_size
self.tp_size = (1 if use_data_parallel else
parallel_state.get_tensor_model_parallel_world_size())
self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
self.hidden_size_per_attention_head = dist_utils.divide(
projection_size, num_heads)
self.num_attention_heads_per_partition = dist_utils.divide(
num_heads, world_size)
num_heads, self.tp_size)

self.qkv = ColumnParallelLinear(input_size=embed_dim,
output_size=3 * projection_size,
quant_config=quant_config,
prefix=f"{prefix}.qkv")
prefix=f"{prefix}.qkv",
disable_tp=use_data_parallel)
self.proj = RowParallelLinear(input_size=projection_size,
output_size=embed_dim,
quant_config=quant_config,
prefix=f"{prefix}.proj")
prefix=f"{prefix}.proj",
disable_tp=use_data_parallel)

# Detect attention implementation.
self.attn_backend = get_vit_attn_backend(
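
Note: the same idea applies to the attention heads. When `use_data_parallel` is set, `tp_size` is forced to 1, so each rank keeps every head rather than a tensor-parallel shard. A minimal sketch of that arithmetic with made-up values (not taken from the model config):

```python
# Hypothetical head count, chosen only to show the partitioning arithmetic.
# dist_utils.divide in the hunk above is an exact integer division with a
# divisibility check; a plain stand-in is used here.
def divide(numerator: int, denominator: int) -> int:
    assert numerator % denominator == 0, "heads must split evenly across ranks"
    return numerator // denominator

num_heads, tp_world_size = 16, 4
print(divide(num_heads, tp_world_size))  # 4 heads per rank under tensor parallelism
print(divide(num_heads, 1))              # 16 heads per rank with use_data_parallel
```
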
@@ -453,6 +460,7 @@ def __init__(
norm_layer: Optional[Callable[[int], nn.Module]] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_data_parallel: bool = False,
) -> None:
super().__init__()
if norm_layer is None:
@@ -465,12 +473,14 @@ def __init__(
num_heads=num_heads,
projection_size=dim,
quant_config=quant_config,
prefix=f"{prefix}.attn")
prefix=f"{prefix}.attn",
use_data_parallel=use_data_parallel)
self.mlp = Qwen2VisionMLP(dim,
mlp_hidden_dim,
act_layer=act_layer,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
prefix=f"{prefix}.mlp",
use_data_parallel=use_data_parallel)

def forward(
self,
@@ -531,6 +541,7 @@ def __init__(
spatial_merge_size: int = 2,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_data_parallel: bool = False,
) -> None:
super().__init__()
self.hidden_size = context_dim * (spatial_merge_size**2)
@@ -542,13 +553,15 @@ def __init__(
self.hidden_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.mlp.0"),
prefix=f"{prefix}.mlp.0",
disable_tp=use_data_parallel),
nn.GELU(),
RowParallelLinear(self.hidden_size,
d_model,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.mlp.2"),
prefix=f"{prefix}.mlp.2",
disable_tp=use_data_parallel),
])

def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -600,6 +613,7 @@ def __init__(
norm_eps: float = 1e-6,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_data_parallel: bool = False,
) -> None:
super().__init__()

@@ -613,6 +627,9 @@ def __init__(
num_heads = vision_config.num_heads
mlp_ratio = vision_config.mlp_ratio

self.use_data_parallel = use_data_parallel
self.out_hidden_size = vision_config.hidden_size

self.spatial_merge_size = spatial_merge_size
self.num_heads = num_heads
self.embed_dim = embed_dim
@@ -634,7 +651,8 @@ def __init__(
mlp_ratio=mlp_ratio,
norm_layer=norm_layer,
quant_config=quant_config,
prefix=f"{prefix}.blocks.{layer_idx}")
prefix=f"{prefix}.blocks.{layer_idx}",
use_data_parallel=use_data_parallel)
for layer_idx in range(depth)
])
self.merger = Qwen2VisionPatchMerger(
@@ -643,6 +661,7 @@ def __init__(
norm_layer=norm_layer,
quant_config=quant_config,
prefix=f"{prefix}.merger",
use_data_parallel=use_data_parallel,
)
self.attn_backend = get_vit_attn_backend(
head_size=head_dim, dtype=torch.get_default_dtype())
@@ -659,8 +678,9 @@ def dtype(self) -> torch.dtype:
def device(self) -> torch.device:
return self.patch_embed.proj.weight.device

def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
def rot_pos_emb(self, grid_thw: list[list[int]]) -> torch.Tensor:
pos_ids = []
max_grid_size = 0
for t, h, w in grid_thw:
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
@@ -678,8 +698,8 @@ def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
).permute(0, 2, 1, 3).flatten()
pos_ids.append(
torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
max_grid_size = max(max_grid_size, h, w)
pos_ids = torch.cat(pos_ids, dim=0)
max_grid_size = grid_thw[:, 1:].max()
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
return rotary_pos_emb
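
Note: since `grid_thw` is now a plain `list[list[int]]` rather than a tensor, the spatial maximum is tracked inside the loop instead of via `grid_thw[:, 1:].max()`. A tiny sanity check of the equivalence with made-up grids:

```python
import torch

grid_thw = [[1, 4, 6], [2, 2, 2]]  # (t, h, w) per item, values made up

max_in_loop = 0
for t, h, w in grid_thw:
    max_in_loop = max(max_in_loop, h, w)

max_from_tensor = torch.tensor(grid_thw)[:, 1:].max().item()
assert max_in_loop == max_from_tensor == 6
```
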
@@ -698,7 +718,7 @@ def compute_attn_mask_seqlen(
def forward(
self,
x: torch.Tensor,
grid_thw: torch.Tensor,
grid_thw: list[list[int]],
) -> torch.Tensor:
# patchify
x = x.to(device=self.device, dtype=self.dtype)
@@ -708,8 +728,9 @@ def forward(
rotary_pos_emb = self.rot_pos_emb(grid_thw)

# compute cu_seqlens
cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2],
grid_thw[:, 0]).cumsum(
grid_thw_ = torch.tensor(grid_thw)
cu_seqlens = torch.repeat_interleave(grid_thw_[:, 1] * grid_thw_[:, 2],
grid_thw_[:, 0]).cumsum(
dim=0, dtype=torch.int32)
cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)

@@ -1112,6 +1133,8 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
"model.": "language_model.model.",
})

supports_encoder_tp_data = True

def get_mrope_input_positions(
self,
input_tokens: list[int],
@@ -1239,6 +1262,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config

self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
self.config = config
self.multimodal_config = multimodal_config

@@ -1249,6 +1273,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=self._maybe_ignore_quant_config(quant_config),
prefix=maybe_prefix(prefix, "visual"),
use_data_parallel=self.use_data_parallel,
)
else:
self.visual = None
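
Note: a hedged usage sketch of how this path gets enabled. `use_data_parallel` is driven by the multimodal config's `mm_encoder_tp_mode` being `"data"`. The constructor-kwarg form below assumes a recent vLLM version exposes the option directly; check your version's engine arguments before relying on it.

```python
# Hedged sketch: enabling the data-parallel vision encoder. The exact way
# to pass mm_encoder_tp_mode may differ across vLLM versions.
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen2-VL-7B-Instruct",
    tensor_parallel_size=4,
    mm_encoder_tp_mode="data",  # run the ViT data-parallel across the TP ranks
)
```
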
@@ -1357,7 +1382,15 @@ def _process_image_input(
image_embeds = image_input["image_embeds"]
else:
pixel_values = image_input["pixel_values"]
image_embeds = self.visual(pixel_values, grid_thw=grid_thw)

if self.use_data_parallel:
return run_dp_sharded_mrope_vision_model(self.visual,
pixel_values,
grid_thw_list,
rope_type="rope_3d")
else:
image_embeds = self.visual(pixel_values,
grid_thw=grid_thw_list)

# Split concatenated embeddings for each image item.
merge_size = self.visual.spatial_merge_size
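
Note: a rough conceptual sketch of what the `run_dp_sharded_mrope_vision_model` call above is doing (this is not vLLM's implementation; the helper below is hypothetical): the images are partitioned across the TP ranks, each rank runs the full vision tower on its share, and the per-image embeddings are gathered so every rank ends up with embeddings for all images in the original order.

```python
import torch.distributed as dist

# Hypothetical illustration only; vLLM's run_dp_sharded_mrope_vision_model
# also handles batching, padding, and mrope details omitted here.
def dp_sharded_encode(visual, pixel_values_list, grid_thw_list,
                      rank: int, world_size: int):
    # Round-robin assignment of images to ranks.
    local = {
        i: visual(pixel_values_list[i], grid_thw=[grid_thw_list[i]])
        for i in range(rank, len(grid_thw_list), world_size)
    }
    # Gather every rank's partial results (object collective for simplicity).
    gathered: list = [None] * world_size
    dist.all_gather_object(gathered, local)
    merged = {}
    for part in gathered:
        merged.update(part)
    # Restore the original image order.
    return [merged[i] for i in range(len(grid_thw_list))]
```
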
@@ -1377,7 +1410,14 @@ def _process_video_input(
video_embeds = video_input["video_embeds"]
else:
pixel_values_videos = video_input["pixel_values_videos"]
video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
if self.use_data_parallel:
return run_dp_sharded_mrope_vision_model(self.visual,
pixel_values_videos,
grid_thw_list,
rope_type="rope_3d")
else:
video_embeds = self.visual(pixel_values_videos,
grid_thw=grid_thw_list)

# Split concatenated embeddings for each video item.
merge_size = self.visual.spatial_merge_size