
Commit 382bde7

david6666666 authored and CarrotShoo committed
[Model] Support DP for ViT on MiniCPM-V-4 (vllm-project#23327)
Signed-off-by: ycyaw66 <497410282@qq.com>
Co-authored-by: ycyaw66 <497410282@qq.com>
1 parent ddf0d07 commit 382bde7

File tree: 4 files changed (+105, -30 lines)


docs/configuration/optimization.md

Lines changed: 1 addition & 0 deletions
@@ -172,6 +172,7 @@ The availablilty of batch-level DP is based on model implementation.
 Currently, the following models support `mm_encoder_tp_mode="data"`:
 
 - Llama4 (<gh-pr:18368>)
+- MiniCPM-V-4 (<gh-pr:23327>)
 - Qwen2.5-VL (<gh-pr:22742>)
 - Step3 (<gh-pr:22697>)
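For context, a minimal usage sketch of the option this doc entry describes. It assumes the `mm_encoder_tp_mode` setting can be passed straight through `LLM(...)` and that the checkpoint id is `openbmb/MiniCPM-V-4`; both the argument plumbing and the model id are assumptions, not taken from this commit.

```python
# Hedged sketch: run MiniCPM-V-4 with the ViT replicated per rank (batch-level DP)
# while the language model stays tensor-parallel. The model id and the exact way
# mm_encoder_tp_mode is forwarded to the engine are assumptions.
from vllm import LLM

llm = LLM(
    model="openbmb/MiniCPM-V-4",   # assumed HF repo id
    trust_remote_code=True,
    tensor_parallel_size=2,        # TP for the LLM backbone
    mm_encoder_tp_mode="data",     # shard images across ranks instead of sharding the ViT
)
```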

vllm/model_executor/models/idefics2_vision_model.py

Lines changed: 96 additions & 27 deletions
@@ -27,13 +27,15 @@
                               Idefics2Config, Idefics2VisionConfig)
 
 from vllm.attention.layer import MultiHeadAttention
-from vllm.distributed import divide, get_tensor_model_parallel_world_size
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
+                                               ReplicatedLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.multimodal.utils import run_dp_sharded_vision_model
 
 
 class Idefics2VisionEmbeddings(nn.Module):
@@ -118,6 +120,7 @@ def __init__(
         config: Idefics2VisionConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
         self.config = config
@@ -130,22 +133,43 @@ def __init__(
                 f" {self.num_heads}).")
         self.scale = self.head_dim**-0.5
         self.dropout = config.attention_dropout
-        self.qkv_proj = QKVParallelLinear(
-            self.embed_dim,
-            self.head_dim,
-            self.num_heads,
-            quant_config=quant_config,
-            prefix=f"{prefix}.qkv_proj",
-        )
-        self.out_proj = RowParallelLinear(
-            self.embed_dim,
-            self.embed_dim,
-            bias=True,
-            quant_config=quant_config,
-            prefix=f"{prefix}.out_proj",
-        )
-        self.tp_size = get_tensor_model_parallel_world_size()
-        self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
+
+        tp_size = (1 if use_data_parallel else
+                   get_tensor_model_parallel_world_size())
+        assert self.num_heads % tp_size == 0
+        self.num_heads_per_partition = self.num_heads // tp_size
+
+        if use_data_parallel:
+            self.q_size = self.num_heads * self.head_dim
+            self.qkv_proj = ReplicatedLinear(
+                self.embed_dim,
+                3 * self.q_size,
+                bias=True,
+                quant_config=quant_config,
+                prefix=f"{prefix}.qkv_proj",
+            )
+            self.out_proj = ReplicatedLinear(
+                self.embed_dim,
+                self.embed_dim,
+                bias=True,
+                quant_config=quant_config,
+                prefix=f"{prefix}.out_proj",
+            )
+        else:
+            self.qkv_proj = QKVParallelLinear(
+                self.embed_dim,
+                self.head_dim,
+                self.num_heads,
+                quant_config=quant_config,
+                prefix=f"{prefix}.qkv_proj",
+            )
+            self.out_proj = RowParallelLinear(
+                self.embed_dim,
+                self.embed_dim,
+                bias=True,
+                quant_config=quant_config,
+                prefix=f"{prefix}.out_proj",
+            )
         self.attn = MultiHeadAttention(self.num_heads_per_partition,
                                        self.head_dim, self.scale)
 
@@ -169,18 +193,23 @@ def __init__(
         config: Idefics2VisionConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
         self.config = config
         self.activation_fn = get_act_fn(config.hidden_act)
-        self.fc1 = ColumnParallelLinear(
+        cls_fc1 = (ReplicatedLinear
+                   if use_data_parallel else ColumnParallelLinear)
+        self.fc1 = cls_fc1(
             config.hidden_size,
             config.intermediate_size,
             bias=True,
             quant_config=quant_config,
             prefix=f"{prefix}.fc1",
         )
-        self.fc2 = RowParallelLinear(
+        cls_fc2 = (ReplicatedLinear
+                   if use_data_parallel else RowParallelLinear)
+        self.fc2 = cls_fc2(
             config.intermediate_size,
             config.hidden_size,
             bias=True,
@@ -202,17 +231,21 @@ def __init__(
         config: Idefics2Config,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
         self.embed_dim = config.hidden_size
-        self.self_attn = Idefics2VisionAttention(config,
-                                                 quant_config=quant_config,
-                                                 prefix=f"{prefix}.self_attn")
+        self.self_attn = Idefics2VisionAttention(
+            config,
+            quant_config=quant_config,
+            prefix=f"{prefix}.self_attn",
+            use_data_parallel=use_data_parallel)
         self.layer_norm1 = nn.LayerNorm(self.embed_dim,
                                         eps=config.layer_norm_eps)
         self.mlp = Idefics2VisionMLP(config,
                                      quant_config=quant_config,
-                                     prefix=f"{prefix}.mlp")
+                                     prefix=f"{prefix}.mlp",
+                                     use_data_parallel=use_data_parallel)
         self.layer_norm2 = nn.LayerNorm(self.embed_dim,
                                         eps=config.layer_norm_eps)
 
@@ -254,6 +287,7 @@ def __init__(
         *,
         num_hidden_layers_override: Optional[int] = None,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
 
@@ -267,7 +301,8 @@ def __init__(
         self.layers = nn.ModuleList([
             Idefics2EncoderLayer(config,
                                  quant_config=quant_config,
-                                 prefix=f"{prefix}.layers.{layer_idx}")
+                                 prefix=f"{prefix}.layers.{layer_idx}",
+                                 use_data_parallel=use_data_parallel)
             for layer_idx in range(num_hidden_layers)
         ])
 
@@ -301,17 +336,20 @@ def __init__(
         num_hidden_layers_override: Optional[int] = None,
         require_post_norm: bool = True,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
 
         embed_dim = config.hidden_size
         self.config = config
+        self.use_data_parallel = use_data_parallel
         self.embeddings = Idefics2VisionEmbeddings(config)
         self.encoder = Idefics2Encoder(
             config,
             quant_config=quant_config,
             num_hidden_layers_override=num_hidden_layers_override,
-            prefix=f"{prefix}.encoder")
+            prefix=f"{prefix}.encoder",
+            use_data_parallel=use_data_parallel)
 
         num_hidden_layers = config.num_hidden_layers
         if len(self.encoder.layers) > config.num_hidden_layers:
@@ -340,10 +378,38 @@ def forward(
             patch_attention_mask=patch_attention_mask,
             tgt_sizes=tgt_sizes,
         )
-        encoder_outputs = self.encoder(hidden_states)
+        if self.use_data_parallel:
+            encoder_outputs = run_dp_sharded_vision_model(
+                hidden_states, self.encoder)
+        else:
+            encoder_outputs = self.encoder(hidden_states)
         last_hidden_state = self.post_layernorm(encoder_outputs)
         return last_hidden_state
 
+    def _consolidate_qkv_weights(
+        self, weights: Iterable[tuple[str, torch.Tensor]]
+    ) -> Iterable[tuple[str, torch.Tensor]]:
+        qkv_idx_mappings = {
+            ".self_attn.q_proj": 0,
+            ".self_attn.k_proj": 1,
+            ".self_attn.v_proj": 2,
+        }
+        qkv_weights = {}
+        for name, loaded_weight in weights:
+            for weight_name, idx in qkv_idx_mappings.items():
+                if weight_name not in name:
+                    continue
+                new_name = name.replace(weight_name, ".self_attn.qkv_proj")
+                if new_name not in qkv_weights:
+                    qkv_weights[new_name] = [None] * 3
+                qkv_weights[new_name][idx] = loaded_weight
+                break
+            else:
+                yield name, loaded_weight
+        for key, weight in qkv_weights.items():
+            qkv_weight = torch.cat(weight, dim=0)
+            yield key, qkv_weight
+
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         stacked_params_mapping = [
@@ -356,6 +422,9 @@ def load_weights(self, weights: Iterable[tuple[str,
         loaded_params: set[str] = set()
         layer_count = len(self.encoder.layers)
 
+        if self.use_data_parallel:
+            weights = self._consolidate_qkv_weights(weights)
+
        for name, loaded_weight in weights:
            # skip pooling header
            if name.startswith("head."):
@@ -373,7 +442,7 @@ def load_weights(self, weights: Iterable[tuple[str,
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
-                if weight_name not in name:
+                if weight_name not in name or self.use_data_parallel:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
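A standalone illustration of the fusion that `_consolidate_qkv_weights` performs on the data-parallel path: because `ReplicatedLinear` holds a single fused weight instead of TP shards, separate `q_proj`/`k_proj`/`v_proj` checkpoint tensors must be concatenated into one `qkv_proj` weight before loading. The sketch below mirrors that logic in plain PyTorch; the parameter names and shapes are made up for the example.

```python
# Toy version of the q/k/v fusion used on the DP path (hypothetical names/shapes).
import torch


def consolidate_qkv(weights):
    mapping = {".q_proj": 0, ".k_proj": 1, ".v_proj": 2}
    fused = {}
    for name, w in weights:
        for key, idx in mapping.items():
            if key in name:
                new_name = name.replace(key, ".qkv_proj")
                fused.setdefault(new_name, [None] * 3)[idx] = w
                break
        else:
            # not a q/k/v tensor: pass it through unchanged
            yield name, w
    for name, parts in fused.items():
        yield name, torch.cat(parts, dim=0)


embed_dim = 8
ckpt = [(f"encoder.layers.0.self_attn.{p}.weight",
         torch.randn(embed_dim, embed_dim))
        for p in ("q_proj", "k_proj", "v_proj")]
out = dict(consolidate_qkv(ckpt))
assert out["encoder.layers.0.self_attn.qkv_proj.weight"].shape == (3 * embed_dim,
                                                                   embed_dim)
```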

vllm/model_executor/models/minicpmv.py

Lines changed: 6 additions & 3 deletions
@@ -778,6 +778,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         # and config class
         self.config = config
         self.multimodal_config = multimodal_config
+        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
 
         self.version = get_version_by_config(self.config)
         self.llm = self.init_llm(vllm_config=vllm_config,
@@ -1325,9 +1326,11 @@ def init_vision_module(
         prefix: str = "",
     ) -> nn.Module:
         quant_config = self._maybe_ignore_quant_config(quant_config)
-        model = Idefics2VisionTransformer(config.vision_config,
-                                          quant_config=quant_config,
-                                          prefix=prefix)
+        model = Idefics2VisionTransformer(
+            config.vision_config,
+            quant_config=quant_config,
+            prefix=prefix,
+            use_data_parallel=self.use_data_parallel)
         if self.config.drop_vision_last_layer:
             model.encoder.layers = model.encoder.layers[:-1]
         return model

vllm/multimodal/utils.py

Lines changed: 2 additions & 0 deletions
@@ -461,6 +461,8 @@ def run_dp_sharded_vision_model(image_input: torch.Tensor,
                                        num_chunks_per_rank, ...]
 
     vision_embeddings = vision_model(image_input_per_rank)
+    # Ensure tensor is contiguous before all_gather
+    vision_embeddings = vision_embeddings.contiguous()
     vision_embeddings = tensor_model_parallel_all_gather(vision_embeddings,
                                                          dim=0)
     vision_embeddings = vision_embeddings[:num_chunks, ...]
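The two added lines guard against the vision model returning a non-contiguous tensor, since collectives such as all_gather expect a contiguous buffer. A tiny self-contained illustration of that contiguity concern (shapes are made up, no distributed setup required):

```python
# Minimal sketch: views produced by transpose/slicing can be non-contiguous,
# and calling .contiguous() materializes a dense buffer before a collective.
import torch

embeddings = torch.randn(8, 4, 16)

row_slice = embeddings[2:4]                 # slicing dim 0 keeps contiguity
strided_view = embeddings.transpose(0, 1)   # transposing does not

print(row_slice.is_contiguous())                   # True
print(strided_view.is_contiguous())                # False
print(strided_view.contiguous().is_contiguous())   # True: safe to gather
```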
