
Commit b794831

666even666 authored and ywang96 committed

[Model] enable data parallel for InternVL vision encoder (vllm-project#23909)

Signed-off-by: Yiwen Chen <yiwen66@berkeley.edu>
Signed-off-by: YiwenC <54658925+666even666@users.noreply.github.com>
Co-authored-by: Roger Wang <hey@rogerw.io>

1 parent 0fb5c4b · commit b794831

3 files changed (+80, -33 lines)


docs/configuration/optimization.md

Lines changed: 1 addition & 0 deletions
@@ -175,6 +175,7 @@ Regardless, you need to set `mm_encoder_tp_mode="data"` in engine arguments to u
 Known supported models:
 
 - GLM-4.5V GLM-4.1V (<gh-pr:23168>)
+- InternVL (<gh-pr:23909>)
 - Kimi-VL (<gh-pr:23817>)
 - Llama4 (<gh-pr:18368>)
 - MiniCPM-V-2.5 or above (<gh-pr:23327>, <gh-pr:23948>)
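
A minimal usage sketch for the engine argument referenced above, assuming it can be passed straight through the offline `LLM` entry point; the checkpoint name is only an illustrative example, not something prescribed by this PR:

    # Sketch: serve InternVL with the vision encoder run data-parallel.
    from vllm import LLM

    llm = LLM(
        model="OpenGVLab/InternVL3-8B",  # example checkpoint (assumption)
        tensor_parallel_size=2,          # language model still uses tensor parallelism
        mm_encoder_tp_mode="data",       # replicate the ViT and shard the image batch
    )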

vllm/model_executor/models/intern_vit.py

Lines changed: 75 additions & 32 deletions
@@ -25,9 +25,11 @@
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
+                                               ReplicatedLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.multimodal.utils import run_dp_sharded_vision_model
 
 NORM2FN = {
     'rms_norm': RMSNorm,
@@ -137,6 +139,7 @@ def __init__(
         *,
         num_dummy_heads: int = 0,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
 
@@ -150,23 +153,34 @@ def __init__(
                 f'(got `embed_dim`: {self.embed_dim} and `num_heads`:'
                 f' {self.num_heads}).')
 
-        self.tp_size = get_tensor_model_parallel_world_size()
-        self.tp_rank = get_tensor_model_parallel_rank()
+        self.tp_size = (1 if use_data_parallel else
+                        get_tensor_model_parallel_world_size())
+        self.tp_rank = (0 if use_data_parallel else
+                        get_tensor_model_parallel_rank())
 
         # Additional dummy heads are used to enable TP for common GPU counts.
         self.dummy_dim = (num_dummy_heads + self.num_heads) * self.head_dim
         self.num_heads_per_partition = divide(num_dummy_heads + self.num_heads,
                                               self.tp_size)
 
         self.scale = self.head_dim**-0.5
-        self.qkv = QKVParallelLinear(
-            self.embed_dim,
-            self.head_dim,
-            num_dummy_heads + self.num_heads,
-            bias=config.qkv_bias,
-            quant_config=quant_config,
-            prefix=f"{prefix}.qkv",
-        )
+        if use_data_parallel:
+            self.qkv = ReplicatedLinear(
+                self.embed_dim,
+                3 * self.head_dim * self.num_heads,
+                bias=config.qkv_bias,
+                quant_config=quant_config,
+                prefix=f"{prefix}.qkv",
+            )
+        else:
+            self.qkv = QKVParallelLinear(
+                self.embed_dim,
+                self.head_dim,
+                num_dummy_heads + self.num_heads,
+                bias=config.qkv_bias,
+                quant_config=quant_config,
+                prefix=f"{prefix}.qkv",
+            )
 
         self.qk_normalization = config.qk_normalization
 
@@ -178,12 +192,20 @@ def __init__(
                                   eps=config.layer_norm_eps,
                                   var_hidden_size=self.embed_dim)
 
-        self.proj = RowParallelLinear(
-            self.dummy_dim,
-            self.embed_dim,
-            quant_config=quant_config,
-            prefix=f"{prefix}.proj",
-        )
+        if use_data_parallel:
+            self.proj = ReplicatedLinear(
+                self.dummy_dim,
+                self.embed_dim,
+                quant_config=quant_config,
+                prefix=f"{prefix}.proj",
+            )
+        else:
+            self.proj = RowParallelLinear(
+                self.dummy_dim,
+                self.embed_dim,
+                quant_config=quant_config,
+                prefix=f"{prefix}.proj",
+            )
 
         self.attn = MultiHeadAttention(self.num_heads_per_partition,
                                        self.head_dim, self.scale)
@@ -287,21 +309,26 @@ def __init__(
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
 
         self.config = config
         self.activation_fn = get_act_fn(config.hidden_act)
-        self.fc1 = ColumnParallelLinear(config.hidden_size,
-                                        config.intermediate_size,
-                                        bias=True,
-                                        quant_config=quant_config,
-                                        prefix=f"{prefix}.fc1")
-        self.fc2 = RowParallelLinear(config.intermediate_size,
-                                     config.hidden_size,
-                                     bias=True,
-                                     quant_config=quant_config,
-                                     prefix=f"{prefix}.fc2")
+        cls_fc1 = (ReplicatedLinear
+                   if use_data_parallel else ColumnParallelLinear)
+        self.fc1 = cls_fc1(config.hidden_size,
+                           config.intermediate_size,
+                           bias=True,
+                           quant_config=quant_config,
+                           prefix=f"{prefix}.fc1")
+        cls_fc2 = (ReplicatedLinear
+                   if use_data_parallel else RowParallelLinear)
+        self.fc2 = cls_fc2(config.intermediate_size,
+                           config.hidden_size,
+                           bias=True,
+                           quant_config=quant_config,
+                           prefix=f"{prefix}.fc2")
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states, _ = self.fc1(hidden_states)
@@ -320,6 +347,7 @@ def __init__(
         *,
         num_dummy_heads: int = 0,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
 
@@ -330,11 +358,13 @@
         self.attn = self._init_attn(config,
                                     quant_config,
                                     num_dummy_heads=num_dummy_heads,
-                                    prefix=f"{prefix}.attn")
+                                    prefix=f"{prefix}.attn",
+                                    use_data_parallel=use_data_parallel)
 
         self.mlp = InternMLP(config,
                              quant_config=quant_config,
-                             prefix=f"{prefix}.mlp")
+                             prefix=f"{prefix}.mlp",
+                             use_data_parallel=use_data_parallel)
         self.norm1 = NORM2FN[self.norm_type](self.embed_dim,
                                              eps=config.layer_norm_eps)
         self.norm2 = NORM2FN[self.norm_type](self.embed_dim,
@@ -352,16 +382,20 @@ def _init_attn(
         *,
         num_dummy_heads: int,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ):
         # fallback to sdpa attention if tp unavailable
-        tp_size = get_tensor_model_parallel_world_size()
+        # tp_size = get_tensor_model_parallel_world_size()
+        tp_size = (1 if use_data_parallel else
+                   get_tensor_model_parallel_world_size())
         num_heads = config.num_attention_heads
 
         if (num_heads + num_dummy_heads) % tp_size == 0:
             return InternParallelAttention(config,
                                            quant_config=quant_config,
                                            num_dummy_heads=num_dummy_heads,
-                                           prefix=prefix)
+                                           prefix=prefix,
+                                           use_data_parallel=use_data_parallel)
 
         return InternSdpaAttention(config, num_dummy_heads=num_dummy_heads)
 
@@ -388,6 +422,7 @@ def __init__(
         num_hidden_layers_override: Optional[int] = None,
         num_dummy_heads: int = 0,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ):
         super().__init__()
 
@@ -402,7 +437,8 @@
             InternVisionEncoderLayer(config,
                                      quant_config,
                                      num_dummy_heads=num_dummy_heads,
-                                     prefix=f"{prefix}.layers.{layer_idx}")
+                                     prefix=f"{prefix}.layers.{layer_idx}",
+                                     use_data_parallel=use_data_parallel)
             for layer_idx in range(num_hidden_layers)
         ])
 
@@ -429,10 +465,12 @@ def __init__(
         num_hidden_layers_override: Optional[int] = None,
         num_dummy_heads: int = 0,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
 
         self.config = config
+        self.use_data_parallel = use_data_parallel
 
         self.embeddings = InternVisionEmbeddings(config)
         self.encoder = InternVisionEncoder(
@@ -441,6 +479,7 @@
             num_hidden_layers_override=num_hidden_layers_override,
             num_dummy_heads=num_dummy_heads,
             prefix=f"{prefix}.encoder",
+            use_data_parallel=use_data_parallel,
         )
 
     def get_input_embeddings(self):
@@ -464,7 +503,11 @@ def forward(
             raise ValueError(
                 f'wrong pixel_values size: {pixel_values.shape}')
 
-        encoder_outputs = self.encoder(inputs_embeds=hidden_states)
+        if self.use_data_parallel:
+            encoder_outputs = run_dp_sharded_vision_model(
+                hidden_states, self.encoder)
+        else:
+            encoder_outputs = self.encoder(inputs_embeds=hidden_states)
 
         return encoder_outputs
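
For context, `run_dp_sharded_vision_model` (imported above from `vllm.multimodal.utils`) is the shared helper already used by the other encoders on the supported-models list: conceptually, it shards the batch of images across the tensor-parallel ranks, runs the now-replicated encoder on each local shard, and gathers the outputs back. The toy sketch below only illustrates that idea in a single process; it is not the vLLM implementation, and the distributed scatter/gather is deliberately simplified.

    import torch
    from torch import nn

    def sharded_vision_forward(pixel_embeds: torch.Tensor,
                               encoder: nn.Module,
                               world_size: int) -> torch.Tensor:
        # Toy, single-process stand-in for DP-sharded vision execution:
        # every "rank" holds full encoder weights and sees only a slice of
        # the image batch; in vLLM the split/gather uses collective ops.
        shards = torch.tensor_split(pixel_embeds, world_size, dim=0)
        outputs = [encoder(shard) for shard in shards]
        return torch.cat(outputs, dim=0)

    # Example: a stand-in encoder and a batch of 7 image patch sequences.
    toy_encoder = nn.Linear(1024, 1024)
    batch = torch.randn(7, 256, 1024)
    out = sharded_vision_forward(batch, toy_encoder, world_size=2)
    assert out.shape == batch.shape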

vllm/model_executor/models/internvl.py

Lines changed: 4 additions & 1 deletion
@@ -1035,6 +1035,8 @@ def get_video_replacement_internvl(item_idx: int):
 class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
                         SupportsLoRA):
 
+    supports_encoder_tp_data = True
+
     @classmethod
     def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
         if modality.startswith("image"):
@@ -1053,6 +1055,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
 
         self.config = config
         self.multimodal_config = multimodal_config
+        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
         self._patch_quant_config(config, quant_config)
 
         image_size = config.force_image_size or config.vision_config.image_size
@@ -1120,7 +1123,7 @@ def _init_vision_model(
                 quant_config=quant_config,
                 num_hidden_layers_override=num_hidden_layers,
                 prefix=prefix,
-            )
+                use_data_parallel=self.use_data_parallel)
         else:
             return InternVisionPatchModel(config.vision_config)
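
Taken together, the model-level edits follow the same opt-in pattern as the other DP-enabled encoders: advertise support with a class attribute, derive a flag from the multimodal config, and thread it into the vision tower. A condensed, hypothetical recap of that pattern (simplified from the classes above, not the full vLLM interface):

    from torch import nn

    class ToyDPVisionChatModel(nn.Module):
        # Lets the engine know the vision encoder may run with mm_encoder_tp_mode="data".
        supports_encoder_tp_data = True

        def __init__(self, multimodal_config) -> None:
            super().__init__()
            # True when the user passed mm_encoder_tp_mode="data" in engine args.
            self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
            # The flag is then forwarded into the vision tower, which builds
            # ReplicatedLinear layers and wraps its forward pass with
            # run_dp_sharded_vision_model (see intern_vit.py above).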
