
Commit 90e07ab

Revert "[Core] Allow disabling TP sharding for parallel Linear layer (vllm-project#23024)"
This reverts commit 53b19cc.
1 parent f856220 · commit 90e07ab
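
This revert swaps back from passing a `disable_tp=` keyword to the parallel Linear layers toward explicitly constructing a replicated layer wherever tensor-parallel sharding is not wanted. A minimal sketch of the two idioms, using hypothetical stand-in classes rather than the real vllm layers (whose constructors also take quant_config, prefix, and so on):

# Sketch only; the Linear stand-ins below are not the vllm classes.
class ReplicatedLinear:
    """Every rank holds the full (output_size x input_size) weight."""

    def __init__(self, input_size: int, output_size: int):
        self.weight_shape = (output_size, input_size)


class ColumnParallelLinear:
    """Output dimension is sharded across tp_size ranks."""

    def __init__(self, input_size: int, output_size: int, tp_size: int = 2):
        self.weight_shape = (output_size // tp_size, input_size)


def build_proj(in_features: int, out_features: int, use_data_parallel: bool):
    # Pattern restored by this revert: pick the class up front.
    # (The reverted change instead passed disable_tp=use_data_parallel
    # to the parallel class.)
    cls = ReplicatedLinear if use_data_parallel else ColumnParallelLinear
    return cls(in_features, out_features)


assert build_proj(1024, 4096, use_data_parallel=True).weight_shape == (4096, 1024)
assert build_proj(1024, 4096, use_data_parallel=False).weight_shape == (2048, 1024)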

File tree

7 files changed: +282 −205 lines


vllm/model_executor/layers/linear.py

Lines changed: 105 additions & 70 deletions
Large diffs are not rendered by default.

vllm/model_executor/model_loader/bitsandbytes_loader.py

Lines changed: 4 additions & 19 deletions
@@ -69,7 +69,6 @@ def __init__(self, load_config: LoadConfig):
         # Store all module names (from transformers) that support
         # BNB quantization.
         self.target_modules: list[str] = []
-        self.tp_disabled_modules: list[str] = []
         # Store the mapping of expert parameters for MoE models.
         self.expert_params_mapping: list[tuple[str, str, int, str]] = []
         # mapping weight names from transformers to vllm.
@@ -323,24 +322,14 @@ def _unquantized_generator(self, hf_weights_files, use_safetensors,
                                quant_state_dict) -> Generator:
         from bitsandbytes.functional import quantize_4bit
 
-        global_tp_size = get_tensor_model_parallel_world_size()
-        global_tp_rank = get_tensor_model_parallel_rank()
+        tp_size = get_tensor_model_parallel_world_size()
+        tp_rank = get_tensor_model_parallel_rank()
 
         for (
                 org_weight_name,
                 mapped_weight_name,
                 weight_tensor,
         ) in self._hf_weight_iter(hf_weights_files, use_safetensors):
-
-            # override tp_size and tp_rank if the module has disabled TP
-            if any(tp_disabled_module in mapped_weight_name
-                   for tp_disabled_module in self.tp_disabled_modules):
-                tp_size = 1
-                tp_rank = 0
-            else:
-                tp_size = global_tp_size
-                tp_rank = global_tp_rank
-
             if any(target_module in mapped_weight_name
                    for target_module in self.target_modules
                    ) and mapped_weight_name.endswith(".weight"):
@@ -429,16 +418,12 @@ def _get_bnb_target_modules(self, model: nn.Module) -> None:
                     # Map vllm's names to transformers's names.
                     rep_name, sub_modules = modules_info
                     for sub_name in sub_modules:
-                        new_name = name.replace(rep_name, sub_name)
-                        self.target_modules.append(new_name)
-                        if module.disable_tp:
-                            self.tp_disabled_modules.append(new_name)
+                        self.target_modules.append(
+                            name.replace(rep_name, sub_name))
                     # Add original module name even if the module has stacked map,
                     # in case model has a mixture of disk-merged and disk-split
                     # weights with same last name.
                     self.target_modules.append(name)
-                    if module.disable_tp:
-                        self.tp_disabled_modules.append(name)
             elif isinstance(module, FusedMoE) and hasattr(
                     module.quant_method, "quant_config"):
                 # TODO: support FusedMoE with prequant and 8bit.
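
With the per-module override removed, the unquantized generator shards every target weight using the global tp_size and tp_rank. A rough sketch of that kind of row sharding, independent of vllm and bitsandbytes (the helper below is illustrative, not the loader's actual code):

# Illustrative helper, not vllm code: slice the output rows owned by one rank.
def shard_rows(rows: list, tp_size: int, tp_rank: int) -> list:
    assert len(rows) % tp_size == 0, "output dim must divide evenly"
    per_rank = len(rows) // tp_size
    return rows[tp_rank * per_rank:(tp_rank + 1) * per_rank]


rows = list(range(8))                                  # pretend 8 output rows
assert shard_rows(rows, tp_size=2, tp_rank=1) == [4, 5, 6, 7]
assert shard_rows(rows, tp_size=1, tp_rank=0) == rows  # replicated case: whole tensor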

vllm/model_executor/models/deepseek_v2.py

Lines changed: 3 additions & 3 deletions
@@ -43,6 +43,7 @@
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
+                                               MergedReplicatedLinear,
                                                ReplicatedLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -434,13 +435,12 @@ def __init__(
         self.max_position_embeddings = max_position_embeddings
 
         if self.q_lora_rank is not None:
-            self.fused_qkv_a_proj = MergedColumnParallelLinear(
+            self.fused_qkv_a_proj = MergedReplicatedLinear(
                 self.hidden_size,
                 [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
                 bias=False,
                 quant_config=quant_config,
-                prefix=f"{prefix}.fused_qkv_a_proj",
-                disable_tp=True)
+                prefix=f"{prefix}.fused_qkv_a_proj")
         else:
             self.kv_a_proj_with_mqa = ReplicatedLinear(
                 self.hidden_size,
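
MergedReplicatedLinear keeps the whole fused projection on every rank and only concatenates the output_sizes. A small illustration of splitting the merged output back into its q and kv parts; the sizes are assumed DeepSeek-V2-style example values, not taken from this diff:

# Example sizes (assumed); the split logic is the point.
q_lora_rank = 1536
kv_lora_rank = 512
qk_rope_head_dim = 64
output_sizes = [q_lora_rank, kv_lora_rank + qk_rope_head_dim]


def split_last_dim(merged: list, sizes: list) -> list:
    parts, start = [], 0
    for size in sizes:
        parts.append(merged[start:start + size])
        start += size
    return parts


merged_out = list(range(sum(output_sizes)))   # stands in for the fused output
q_a, kv_a = split_last_dim(merged_out, output_sizes)
assert (len(q_a), len(kv_a)) == (1536, 576)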

vllm/model_executor/models/glm4_1v.py

Lines changed: 78 additions & 50 deletions
@@ -51,10 +51,14 @@
 from vllm.logger import init_logger
 from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.layers.layernorm import RMSNorm
+# yapf: disable
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
+                                               MergedReplicatedLinear,
                                                QKVParallelLinear,
+                                               ReplicatedLinear,
                                                RowParallelLinear)
+# yapf: enable
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.module_mapping import MultiModelKeys
@@ -170,22 +174,20 @@ def __init__(
         use_data_parallel: bool = False,
     ):
         super().__init__()
-        self.gate_up_proj = MergedColumnParallelLinear(
-            input_size=in_features,
-            output_sizes=[hidden_features] * 2,
-            bias=bias,
-            quant_config=quant_config,
-            prefix=f"{prefix}.gate_up_proj",
-            disable_tp=use_data_parallel,
-        )
-        self.down_proj = RowParallelLinear(
-            hidden_features,
-            in_features,
-            bias=bias,
-            quant_config=quant_config,
-            prefix=f"{prefix}.down_proj",
-            disable_tp=use_data_parallel,
-        )
+        cls_gate_up = (MergedReplicatedLinear
+                       if use_data_parallel else MergedColumnParallelLinear)
+        self.gate_up_proj = cls_gate_up(input_size=in_features,
+                                        output_sizes=[hidden_features] * 2,
+                                        bias=bias,
+                                        quant_config=quant_config,
+                                        prefix=f"{prefix}.gate_up_proj")
+        cls_down = (ReplicatedLinear
+                    if use_data_parallel else RowParallelLinear)
+        self.down_proj = cls_down(hidden_features,
+                                  in_features,
+                                  bias=bias,
+                                  quant_config=quant_config,
+                                  prefix=f"{prefix}.down_proj")
         self.act_fn = SiluAndMul()
 
     def forward(self, x: torch.Tensor):
@@ -232,32 +234,48 @@ def __init__(
         # Per attention head and per partition values.
         self.tp_size = (1 if use_data_parallel else
                         get_tensor_model_parallel_world_size())
-        self.tp_rank = (0 if use_data_parallel else
-                        parallel_state.get_tensor_model_parallel_rank())
+        self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
         self.hidden_size_per_attention_head = dist_utils.divide(
             projection_size, num_heads)
         self.num_attention_heads_per_partition = dist_utils.divide(
             num_heads, self.tp_size)
 
-        self.qkv = QKVParallelLinear(
-            hidden_size=embed_dim,
-            head_size=self.hidden_size_per_attention_head,
-            total_num_heads=num_heads,
-            total_num_kv_heads=num_heads,
-            bias=False,
-            quant_config=quant_config,
-            # Change qkv prefix to align with GLM-4.5V-FP8 quantization cfg
-            prefix=f"{prefix}.qkv_proj" if quant_config else f"{prefix}.qkv",
-            disable_tp=use_data_parallel,
-        )
-        self.proj = RowParallelLinear(
-            input_size=projection_size,
-            output_size=embed_dim,
-            quant_config=quant_config,
-            prefix=f"{prefix}.proj",
-            bias=False,
-            disable_tp=use_data_parallel,
-        )
+        if use_data_parallel:
+            self.qkv = ReplicatedLinear(
+                input_size=embed_dim,
+                output_size=3 * projection_size,
+                bias=False,
+                quant_config=quant_config,
+                # Change qkv prefix to align with GLM-4.5V-FP8 quantization cfg
+                prefix=f"{prefix}.qkv_proj"
+                if quant_config else f"{prefix}.qkv",
+            )
+            self.proj = ReplicatedLinear(
+                input_size=projection_size,
+                output_size=embed_dim,
+                quant_config=quant_config,
+                prefix=f"{prefix}.proj",
+                bias=False,
+            )
+        else:
+            self.qkv = QKVParallelLinear(
+                hidden_size=embed_dim,
+                head_size=self.hidden_size_per_attention_head,
+                total_num_heads=num_heads,
+                total_num_kv_heads=num_heads,
+                bias=False,
+                quant_config=quant_config,
+                # Change qkv prefix to align with GLM-4.5V-FP8 quantization cfg
+                prefix=f"{prefix}.qkv_proj"
+                if quant_config else f"{prefix}.qkv",
+            )
+            self.proj = RowParallelLinear(
+                input_size=projection_size,
+                output_size=embed_dim,
+                quant_config=quant_config,
+                prefix=f"{prefix}.proj",
+                bias=False,
+            )
 
         # Detect attention implementation.
         self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
@@ -476,31 +494,41 @@ def __init__(
     ) -> None:
         super().__init__()
         self.hidden_size = d_model
-        self.proj = ColumnParallelLinear(
-            self.hidden_size,
-            self.hidden_size,
-            bias=bias,
-            gather_output=True,
-            quant_config=quant_config,
-            prefix=f"{prefix}.proj",
-            disable_tp=use_data_parallel,
-        )
+        if use_data_parallel:
+            self.proj = ReplicatedLinear(
+                input_size=self.hidden_size,
+                output_size=self.hidden_size,
+                bias=bias,
+                quant_config=quant_config,
+                prefix=f"{prefix}.proj",
+            )
+        else:
+            self.proj = ColumnParallelLinear(
+                self.hidden_size,
+                self.hidden_size,
+                bias=bias,
+                gather_output=True,
+                quant_config=quant_config,
+                prefix=f"{prefix}.proj",
+            )
         self.post_projection_norm = nn.LayerNorm(self.hidden_size)
-        self.gate_up_proj = MergedColumnParallelLinear(
+        cls_gate_up = (MergedReplicatedLinear
+                       if use_data_parallel else MergedColumnParallelLinear)
+        self.gate_up_proj = cls_gate_up(
             input_size=self.hidden_size,
             output_sizes=[context_dim] * 2,
             bias=bias,
             quant_config=quant_config,
             prefix=f"{prefix}.gate_up_proj",
-            disable_tp=use_data_parallel,
         )
-        self.down_proj = RowParallelLinear(
+        cls_down = (ReplicatedLinear
+                    if use_data_parallel else RowParallelLinear)
+        self.down_proj = cls_down(
            context_dim,
            self.hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
-            disable_tp=use_data_parallel,
         )
         self.act_fn = SiluAndMul()
         self.extra_activation_func = nn.GELU()
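
In the data-parallel path the vision attention keeps tp_size = 1, so each rank keeps all heads, while tp_rank still reflects the real tensor-parallel rank. A tiny arithmetic check of that head partitioning (numbers are made up for the example):

# Illustrative only: head count per partition under the two modes.
def heads_per_partition(num_heads: int, tp_world_size: int,
                        use_data_parallel: bool) -> int:
    tp_size = 1 if use_data_parallel else tp_world_size
    assert num_heads % tp_size == 0, "heads must divide evenly across ranks"
    return num_heads // tp_size


assert heads_per_partition(16, tp_world_size=4, use_data_parallel=True) == 16
assert heads_per_partition(16, tp_world_size=4, use_data_parallel=False) == 4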

vllm/model_executor/models/qwen2_5_vl.py

Lines changed: 37 additions & 25 deletions
@@ -48,6 +48,7 @@
 # yapf: disable
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
+                                               MergedReplicatedLinear,
                                                QKVParallelLinear,
                                                ReplicatedLinear,
                                                RowParallelLinear)
@@ -177,20 +178,22 @@ def __init__(self,
                  prefix: str = "",
                  use_data_parallel: bool = False):
         super().__init__()
-        self.gate_up_proj = MergedColumnParallelLinear(
+        cls_gate_up_proj = (MergedReplicatedLinear if use_data_parallel else
+                            MergedColumnParallelLinear)
+        self.gate_up_proj = cls_gate_up_proj(
             input_size=in_features,
             output_sizes=[hidden_features] * 2,  # [gate_proj, up_proj]
             bias=bias,
             quant_config=quant_config,
-            prefix=f"{prefix}.gate_up_proj",
-            disable_tp=use_data_parallel)
-
-        self.down_proj = RowParallelLinear(hidden_features,
-                                           in_features,
-                                           bias=bias,
-                                           quant_config=quant_config,
-                                           prefix=f"{prefix}.down_proj",
-                                           disable_tp=use_data_parallel)
+            prefix=f"{prefix}.gate_up_proj")
+
+        cls_down_proj = (ReplicatedLinear
+                         if use_data_parallel else RowParallelLinear)
+        self.down_proj = cls_down_proj(hidden_features,
+                                       in_features,
+                                       bias=bias,
+                                       quant_config=quant_config,
+                                       prefix=f"{prefix}.down_proj")
         self.act_fn = act_fn
 
     def forward(self, x: torch.Tensor):
@@ -240,21 +243,30 @@ def __init__(
         self.num_attention_heads_per_partition = dist_utils.divide(
             num_heads, self.tp_size)
 
-        self.qkv = QKVParallelLinear(
-            hidden_size=embed_dim,
-            head_size=self.hidden_size_per_attention_head,
-            total_num_heads=num_heads,
-            total_num_kv_heads=num_heads,
-            bias=True,
-            quant_config=quant_config,
-            prefix=f"{prefix}.qkv",
-            disable_tp=use_data_parallel)
-
-        self.proj = RowParallelLinear(input_size=projection_size,
-                                      output_size=embed_dim,
-                                      quant_config=quant_config,
-                                      prefix=f"{prefix}.proj",
-                                      disable_tp=use_data_parallel)
+        if use_data_parallel:
+            self.qkv = ReplicatedLinear(embed_dim,
+                                        self.hidden_size_per_attention_head *
+                                        3 * num_heads,
+                                        bias=True,
+                                        quant_config=quant_config,
+                                        prefix=f"{prefix}.qkv")
+
+        else:
+            self.qkv = QKVParallelLinear(
+                hidden_size=embed_dim,
+                head_size=self.hidden_size_per_attention_head,
+                total_num_heads=num_heads,
+                total_num_kv_heads=num_heads,
+                bias=True,
+                quant_config=quant_config,
+                prefix=f"{prefix}.qkv")
+
+        cls_proj = (ReplicatedLinear
+                    if use_data_parallel else RowParallelLinear)
+        self.proj = cls_proj(input_size=projection_size,
+                             output_size=embed_dim,
+                             quant_config=quant_config,
+                             prefix=f"{prefix}.proj")
 
         # Detect attention implementation.
         self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
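
The ReplicatedLinear qkv in the data-parallel branch sizes its output as hidden_size_per_attention_head * 3 * num_heads, which equals the 3 * projection_size that the sharded QKVParallelLinear covers in total across ranks. A quick check with example numbers (made up, not read from any model config):

# Example values only; the identity holds whenever projection_size % num_heads == 0.
num_heads = 16
projection_size = 1280
hidden_size_per_attention_head = projection_size // num_heads   # 80
assert hidden_size_per_attention_head * 3 * num_heads == 3 * projection_size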
