|
51 | 51 | from vllm.logger import init_logger |
52 | 52 | from vllm.model_executor import SamplingMetadata |
53 | 53 | from vllm.model_executor.layers.layernorm import RMSNorm |
54 | | -# yapf: disable |
55 | 54 | from vllm.model_executor.layers.linear import (ColumnParallelLinear, |
56 | 55 | MergedColumnParallelLinear, |
57 | | - MergedReplicatedLinear, |
58 | 56 | QKVParallelLinear, |
59 | | - ReplicatedLinear, |
60 | 57 | RowParallelLinear) |
61 | | -# yapf: enable |
62 | 58 | from vllm.model_executor.layers.quantization import QuantizationConfig |
63 | 59 | from vllm.model_executor.model_loader.weight_utils import default_weight_loader |
64 | 60 | from vllm.model_executor.models.module_mapping import MultiModelKeys |
@@ -174,20 +170,22 @@ def __init__( |
174 | 170 | use_data_parallel: bool = False, |
175 | 171 | ): |
176 | 172 | super().__init__() |
177 | | - cls_gate_up = (MergedReplicatedLinear |
178 | | - if use_data_parallel else MergedColumnParallelLinear) |
179 | | - self.gate_up_proj = cls_gate_up(input_size=in_features, |
180 | | - output_sizes=[hidden_features] * 2, |
181 | | - bias=bias, |
182 | | - quant_config=quant_config, |
183 | | - prefix=f"{prefix}.gate_up_proj") |
184 | | - cls_down = (ReplicatedLinear |
185 | | - if use_data_parallel else RowParallelLinear) |
186 | | - self.down_proj = cls_down(hidden_features, |
187 | | - in_features, |
188 | | - bias=bias, |
189 | | - quant_config=quant_config, |
190 | | - prefix=f"{prefix}.down_proj") |
| 173 | + self.gate_up_proj = MergedColumnParallelLinear( |
| 174 | + input_size=in_features, |
| 175 | + output_sizes=[hidden_features] * 2, |
| 176 | + bias=bias, |
| 177 | + quant_config=quant_config, |
| 178 | + prefix=f"{prefix}.gate_up_proj", |
| 179 | + disable_tp=use_data_parallel, |
| 180 | + ) |
| 181 | + self.down_proj = RowParallelLinear( |
| 182 | + hidden_features, |
| 183 | + in_features, |
| 184 | + bias=bias, |
| 185 | + quant_config=quant_config, |
| 186 | + prefix=f"{prefix}.down_proj", |
| 187 | + disable_tp=use_data_parallel, |
| 188 | + ) |
191 | 189 | self.act_fn = SiluAndMul() |
192 | 190 | |
193 | 191 | def forward(self, x: torch.Tensor): |
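This hunk drops the `use_data_parallel` branching: instead of switching to `MergedReplicatedLinear`/`ReplicatedLinear`, the MLP always builds the TP-aware layers and turns sharding off per layer with `disable_tp=use_data_parallel`. The sketch below illustrates that pattern with plain `torch.nn` stand-ins; `ToyShardedLinear` and `ToyVisionMLP` are hypothetical names, not vLLM classes, and only the constructor flag mirrors the diff.

```python
# Minimal sketch (hypothetical classes, not vLLM's implementation) of the
# "always use the parallel layer, opt out via disable_tp" pattern.
import torch
import torch.nn as nn


class ToyShardedLinear(nn.Module):
    """Stand-in for a column-parallel linear: shards output features
    across `tp_size` ranks unless `disable_tp=True`."""

    def __init__(self, in_features: int, out_features: int, *,
                 tp_size: int = 1, disable_tp: bool = False):
        super().__init__()
        # Data-parallel callers behave like a single-rank TP group.
        shards = 1 if disable_tp else tp_size
        self.linear = nn.Linear(in_features, out_features // shards)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


class ToyVisionMLP(nn.Module):
    def __init__(self, hidden: int, intermediate: int, *,
                 tp_size: int, use_data_parallel: bool):
        super().__init__()
        # Mirrors the diff: one code path, the flag decides whether the
        # gate/up projection is sharded or kept whole on this rank.
        self.gate_up_proj = ToyShardedLinear(hidden, 2 * intermediate,
                                             tp_size=tp_size,
                                             disable_tp=use_data_parallel)


# Per-rank output width: sharded under TP, full width under DP.
x = torch.randn(4, 64)
tp_mlp = ToyVisionMLP(64, 256, tp_size=4, use_data_parallel=False)
dp_mlp = ToyVisionMLP(64, 256, tp_size=4, use_data_parallel=True)
print(tp_mlp.gate_up_proj(x).shape)  # torch.Size([4, 128])
print(dp_mlp.gate_up_proj(x).shape)  # torch.Size([4, 512])
```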
@@ -234,48 +232,32 @@ def __init__( |
234 | 232 | # Per attention head and per partition values. |
235 | 233 | self.tp_size = (1 if use_data_parallel else |
236 | 234 | get_tensor_model_parallel_world_size()) |
237 | | - self.tp_rank = parallel_state.get_tensor_model_parallel_rank() |
| 235 | + self.tp_rank = (0 if use_data_parallel else |
| 236 | + parallel_state.get_tensor_model_parallel_rank()) |
238 | 237 | self.hidden_size_per_attention_head = dist_utils.divide( |
239 | 238 | projection_size, num_heads) |
240 | 239 | self.num_attention_heads_per_partition = dist_utils.divide( |
241 | 240 | num_heads, self.tp_size) |
242 | 241 | |
243 | | - if use_data_parallel: |
244 | | - self.qkv = ReplicatedLinear( |
245 | | - input_size=embed_dim, |
246 | | - output_size=3 * projection_size, |
247 | | - bias=False, |
248 | | - quant_config=quant_config, |
249 | | - # Change qkv prefix to align with GLM-4.5V-FP8 quantization cfg |
250 | | - prefix=f"{prefix}.qkv_proj" |
251 | | - if quant_config else f"{prefix}.qkv", |
252 | | - ) |
253 | | - self.proj = ReplicatedLinear( |
254 | | - input_size=projection_size, |
255 | | - output_size=embed_dim, |
256 | | - quant_config=quant_config, |
257 | | - prefix=f"{prefix}.proj", |
258 | | - bias=False, |
259 | | - ) |
260 | | - else: |
261 | | - self.qkv = QKVParallelLinear( |
262 | | - hidden_size=embed_dim, |
263 | | - head_size=self.hidden_size_per_attention_head, |
264 | | - total_num_heads=num_heads, |
265 | | - total_num_kv_heads=num_heads, |
266 | | - bias=False, |
267 | | - quant_config=quant_config, |
268 | | - # Change qkv prefix to align with GLM-4.5V-FP8 quantization cfg |
269 | | - prefix=f"{prefix}.qkv_proj" |
270 | | - if quant_config else f"{prefix}.qkv", |
271 | | - ) |
272 | | - self.proj = RowParallelLinear( |
273 | | - input_size=projection_size, |
274 | | - output_size=embed_dim, |
275 | | - quant_config=quant_config, |
276 | | - prefix=f"{prefix}.proj", |
277 | | - bias=False, |
278 | | - ) |
| 242 | + self.qkv = QKVParallelLinear( |
| 243 | + hidden_size=embed_dim, |
| 244 | + head_size=self.hidden_size_per_attention_head, |
| 245 | + total_num_heads=num_heads, |
| 246 | + total_num_kv_heads=num_heads, |
| 247 | + bias=False, |
| 248 | + quant_config=quant_config, |
| 249 | + # Change qkv prefix to align with GLM-4.5V-FP8 quantization cfg |
| 250 | + prefix=f"{prefix}.qkv_proj" if quant_config else f"{prefix}.qkv", |
| 251 | + disable_tp=use_data_parallel, |
| 252 | + ) |
| 253 | + self.proj = RowParallelLinear( |
| 254 | + input_size=projection_size, |
| 255 | + output_size=embed_dim, |
| 256 | + quant_config=quant_config, |
| 257 | + prefix=f"{prefix}.proj", |
| 258 | + bias=False, |
| 259 | + disable_tp=use_data_parallel, |
| 260 | + ) |
279 | 261 | |
280 | 262 | # Detect attention implementation. |
281 | 263 | self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) |
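In data-parallel mode the attention block now also forces `tp_rank` to 0 (previously only `tp_size` was overridden), so the per-partition head bookkeeping is consistent in both modes. The sketch below walks through that bookkeeping with a toy `divide` assumed to match the exact-division semantics of `dist_utils.divide`; the sizes are illustrative only.

```python
# Sketch of the per-partition head bookkeeping from the hunk above.
# `divide` mimics dist_utils.divide: exact integer division with a check.
def divide(numerator: int, denominator: int) -> int:
    assert numerator % denominator == 0, "not evenly divisible"
    return numerator // denominator


def head_partition(num_heads: int, projection_size: int, *,
                   tp_world_size: int, tp_world_rank: int,
                   use_data_parallel: bool):
    # DP mode behaves like a single-rank TP group: size 1, rank 0.
    tp_size = 1 if use_data_parallel else tp_world_size
    tp_rank = 0 if use_data_parallel else tp_world_rank
    head_dim = divide(projection_size, num_heads)
    heads_per_partition = divide(num_heads, tp_size)
    return tp_size, tp_rank, head_dim, heads_per_partition


# Example: 16 heads, 2048-dim projection, TP world size 4, rank 2.
print(head_partition(16, 2048, tp_world_size=4, tp_world_rank=2,
                     use_data_parallel=False))  # (4, 2, 128, 4)
print(head_partition(16, 2048, tp_world_size=4, tp_world_rank=2,
                     use_data_parallel=True))   # (1, 0, 128, 16)
```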
@@ -494,41 +476,31 @@ def __init__( |
494 | 476 | ) -> None: |
495 | 477 | super().__init__() |
496 | 478 | self.hidden_size = d_model |
497 | | - if use_data_parallel: |
498 | | - self.proj = ReplicatedLinear( |
499 | | - input_size=self.hidden_size, |
500 | | - output_size=self.hidden_size, |
501 | | - bias=bias, |
502 | | - quant_config=quant_config, |
503 | | - prefix=f"{prefix}.proj", |
504 | | - ) |
505 | | - else: |
506 | | - self.proj = ColumnParallelLinear( |
507 | | - self.hidden_size, |
508 | | - self.hidden_size, |
509 | | - bias=bias, |
510 | | - gather_output=True, |
511 | | - quant_config=quant_config, |
512 | | - prefix=f"{prefix}.proj", |
513 | | - ) |
| 479 | + self.proj = ColumnParallelLinear( |
| 480 | + self.hidden_size, |
| 481 | + self.hidden_size, |
| 482 | + bias=bias, |
| 483 | + gather_output=True, |
| 484 | + quant_config=quant_config, |
| 485 | + prefix=f"{prefix}.proj", |
| 486 | + disable_tp=use_data_parallel, |
| 487 | + ) |
514 | 488 | self.post_projection_norm = nn.LayerNorm(self.hidden_size) |
515 | | - cls_gate_up = (MergedReplicatedLinear |
516 | | - if use_data_parallel else MergedColumnParallelLinear) |
517 | | - self.gate_up_proj = cls_gate_up( |
| 489 | + self.gate_up_proj = MergedColumnParallelLinear( |
518 | 490 | input_size=self.hidden_size, |
519 | 491 | output_sizes=[context_dim] * 2, |
520 | 492 | bias=bias, |
521 | 493 | quant_config=quant_config, |
522 | 494 | prefix=f"{prefix}.gate_up_proj", |
| 495 | + disable_tp=use_data_parallel, |
523 | 496 | ) |
524 | | - cls_down = (ReplicatedLinear |
525 | | - if use_data_parallel else RowParallelLinear) |
526 | | - self.down_proj = cls_down( |
| 497 | + self.down_proj = RowParallelLinear( |
527 | 498 | context_dim, |
528 | 499 | self.hidden_size, |
529 | 500 | bias=bias, |
530 | 501 | quant_config=quant_config, |
531 | 502 | prefix=f"{prefix}.down_proj", |
| 503 | + disable_tp=use_data_parallel, |
532 | 504 | ) |
533 | 505 | self.act_fn = SiluAndMul() |
534 | 506 | self.extra_activation_func = nn.GELU() |
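In this block the projection keeps `gather_output=True`, so under tensor parallelism the following `nn.LayerNorm(self.hidden_size)` still sees the full, unsharded hidden dimension, while `disable_tp=use_data_parallel` makes the data-parallel case degenerate to an ordinary linear layer. The shape-only sketch below illustrates that constraint with plain tensors rather than vLLM's layers; the sizes are made up for illustration.

```python
# Shape-only illustration of why the projection gathers its output:
# the LayerNorm that follows expects the full hidden size on every rank.
import torch
import torch.nn as nn

hidden_size = 1024
tp_size = 4

# Column-parallel output without gathering: each rank would only hold
# hidden_size // tp_size features, which LayerNorm(hidden_size) rejects.
sharded_out = torch.randn(2, 16, hidden_size // tp_size)

# With gather_output=True (or disable_tp=True in the DP case), every rank
# ends up with the full hidden dimension.
gathered_out = torch.randn(2, 16, hidden_size)

norm = nn.LayerNorm(hidden_size)
print(norm(gathered_out).shape)  # torch.Size([2, 16, 1024]) -- OK
try:
    norm(sharded_out)            # last dim 256 != 1024
except RuntimeError as err:
    print("sharded output fails LayerNorm:", type(err).__name__)
```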