51 | 51 | from vllm.logger import init_logger |
52 | 52 | from vllm.model_executor import SamplingMetadata |
53 | 53 | from vllm.model_executor.layers.layernorm import RMSNorm |
| 54 | +# yapf: disable |
54 | 55 | from vllm.model_executor.layers.linear import (ColumnParallelLinear, |
55 | 56 | MergedColumnParallelLinear, |
| 57 | + MergedReplicatedLinear, |
56 | 58 | QKVParallelLinear, |
| 59 | + ReplicatedLinear, |
57 | 60 | RowParallelLinear) |
| 61 | +# yapf: enable |
58 | 62 | from vllm.model_executor.layers.quantization import QuantizationConfig |
59 | 63 | from vllm.model_executor.model_loader.weight_utils import default_weight_loader |
60 | 64 | from vllm.model_executor.models.module_mapping import MultiModelKeys |
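
The pattern this diff applies throughout the file is to choose a replicated linear layer when the vision encoder runs data-parallel (every rank keeps the full weight, so no collective is needed) and the tensor-parallel variant otherwise; the `# yapf: disable` / `# yapf: enable` pair simply keeps the formatter from reflowing the grouped import. A minimal sketch of that selection, reusing the constructor arguments seen above (the helper name `make_proj` is illustrative, not part of the PR):

    from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                   ReplicatedLinear)

    def make_proj(in_size: int, out_size: int, use_data_parallel: bool,
                  quant_config=None, prefix: str = ""):
        # Replicated: the full [out_size, in_size] weight lives on every rank.
        # ColumnParallel: the weight is sharded across TP ranks along out_size.
        cls = ReplicatedLinear if use_data_parallel else ColumnParallelLinear
        return cls(in_size, out_size, bias=False,
                   quant_config=quant_config, prefix=prefix)
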
@@ -170,22 +174,20 @@ def __init__( |
170 | 174 | use_data_parallel: bool = False, |
171 | 175 | ): |
172 | 176 | super().__init__() |
173 | | - self.gate_up_proj = MergedColumnParallelLinear( |
174 | | - input_size=in_features, |
175 | | - output_sizes=[hidden_features] * 2, |
176 | | - bias=bias, |
177 | | - quant_config=quant_config, |
178 | | - prefix=f"{prefix}.gate_up_proj", |
179 | | - disable_tp=use_data_parallel, |
180 | | - ) |
181 | | - self.down_proj = RowParallelLinear( |
182 | | - hidden_features, |
183 | | - in_features, |
184 | | - bias=bias, |
185 | | - quant_config=quant_config, |
186 | | - prefix=f"{prefix}.down_proj", |
187 | | - disable_tp=use_data_parallel, |
188 | | - ) |
| 177 | + cls_gate_up = (MergedReplicatedLinear |
| 178 | + if use_data_parallel else MergedColumnParallelLinear) |
| 179 | + self.gate_up_proj = cls_gate_up(input_size=in_features, |
| 180 | + output_sizes=[hidden_features] * 2, |
| 181 | + bias=bias, |
| 182 | + quant_config=quant_config, |
| 183 | + prefix=f"{prefix}.gate_up_proj") |
| 184 | + cls_down = (ReplicatedLinear |
| 185 | + if use_data_parallel else RowParallelLinear) |
| 186 | + self.down_proj = cls_down(hidden_features, |
| 187 | + in_features, |
| 188 | + bias=bias, |
| 189 | + quant_config=quant_config, |
| 190 | + prefix=f"{prefix}.down_proj") |
189 | 191 | self.act_fn = SiluAndMul() |
190 | 192 |
191 | 193 | def forward(self, x: torch.Tensor): |
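
Both branches return the familiar (output, bias) tuple with identical shapes, so the forward pass whose signature appears in the trailing context line is unaffected by the swap. A sketch of that flow, assuming the usual merged gate/up layout consumed by SiluAndMul (the body itself is outside this hunk):

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # [.., in_features] -> [.., 2 * hidden_features], same for either class above
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)   # SiluAndMul: silu(gate) * up -> [.., hidden_features]
        x, _ = self.down_proj(x)   # [.., hidden_features] -> [.., in_features]
        return x
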
@@ -232,32 +234,48 @@ def __init__( |
232 | 234 | # Per attention head and per partition values. |
233 | 235 | self.tp_size = (1 if use_data_parallel else |
234 | 236 | get_tensor_model_parallel_world_size()) |
235 | | - self.tp_rank = (0 if use_data_parallel else |
236 | | - parallel_state.get_tensor_model_parallel_rank()) |
| 237 | + self.tp_rank = parallel_state.get_tensor_model_parallel_rank() |
237 | 238 | self.hidden_size_per_attention_head = dist_utils.divide( |
238 | 239 | projection_size, num_heads) |
239 | 240 | self.num_attention_heads_per_partition = dist_utils.divide( |
240 | 241 | num_heads, self.tp_size) |
241 | 242 |
242 | | - self.qkv = QKVParallelLinear( |
243 | | - hidden_size=embed_dim, |
244 | | - head_size=self.hidden_size_per_attention_head, |
245 | | - total_num_heads=num_heads, |
246 | | - total_num_kv_heads=num_heads, |
247 | | - bias=False, |
248 | | - quant_config=quant_config, |
249 | | - # Change qkv prefix to align with GLM-4.5V-FP8 quantization cfg |
250 | | - prefix=f"{prefix}.qkv_proj" if quant_config else f"{prefix}.qkv", |
251 | | - disable_tp=use_data_parallel, |
252 | | - ) |
253 | | - self.proj = RowParallelLinear( |
254 | | - input_size=projection_size, |
255 | | - output_size=embed_dim, |
256 | | - quant_config=quant_config, |
257 | | - prefix=f"{prefix}.proj", |
258 | | - bias=False, |
259 | | - disable_tp=use_data_parallel, |
260 | | - ) |
| 243 | + if use_data_parallel: |
| 244 | + self.qkv = ReplicatedLinear( |
| 245 | + input_size=embed_dim, |
| 246 | + output_size=3 * projection_size, |
| 247 | + bias=False, |
| 248 | + quant_config=quant_config, |
| 249 | + # Change qkv prefix to align with GLM-4.5V-FP8 quantization cfg |
| 250 | + prefix=f"{prefix}.qkv_proj" |
| 251 | + if quant_config else f"{prefix}.qkv", |
| 252 | + ) |
| 253 | + self.proj = ReplicatedLinear( |
| 254 | + input_size=projection_size, |
| 255 | + output_size=embed_dim, |
| 256 | + quant_config=quant_config, |
| 257 | + prefix=f"{prefix}.proj", |
| 258 | + bias=False, |
| 259 | + ) |
| 260 | + else: |
| 261 | + self.qkv = QKVParallelLinear( |
| 262 | + hidden_size=embed_dim, |
| 263 | + head_size=self.hidden_size_per_attention_head, |
| 264 | + total_num_heads=num_heads, |
| 265 | + total_num_kv_heads=num_heads, |
| 266 | + bias=False, |
| 267 | + quant_config=quant_config, |
| 268 | + # Change qkv prefix to align with GLM-4.5V-FP8 quantization cfg |
| 269 | + prefix=f"{prefix}.qkv_proj" |
| 270 | + if quant_config else f"{prefix}.qkv", |
| 271 | + ) |
| 272 | + self.proj = RowParallelLinear( |
| 273 | + input_size=projection_size, |
| 274 | + output_size=embed_dim, |
| 275 | + quant_config=quant_config, |
| 276 | + prefix=f"{prefix}.proj", |
| 277 | + bias=False, |
| 278 | + ) |
261 | 279 |
262 | 280 | # Detect attention implementation. |
263 | 281 | self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) |
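
In the data-parallel branch, ReplicatedLinear with output_size=3 * projection_size yields the same fused [q | k | v] layout that QKVParallelLinear produces when tp_size == 1, so whatever splits the fused tensor downstream sees identical shapes. A standalone illustration of that layout, with assumed dimensions rather than code from the PR:

    import torch

    seq_len, num_heads, head_dim = 16, 8, 64
    projection_size = num_heads * head_dim
    qkv = torch.randn(seq_len, 3 * projection_size)  # stand-in for the fused qkv output
    q, k, v = qkv.chunk(3, dim=-1)                   # each [seq_len, projection_size]
    q = q.view(seq_len, num_heads, head_dim)         # per-head view for attention
    assert q.shape == (seq_len, num_heads, head_dim)
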
@@ -476,31 +494,41 @@ def __init__( |
476 | 494 | ) -> None: |
477 | 495 | super().__init__() |
478 | 496 | self.hidden_size = d_model |
479 | | - self.proj = ColumnParallelLinear( |
480 | | - self.hidden_size, |
481 | | - self.hidden_size, |
482 | | - bias=bias, |
483 | | - gather_output=True, |
484 | | - quant_config=quant_config, |
485 | | - prefix=f"{prefix}.proj", |
486 | | - disable_tp=use_data_parallel, |
487 | | - ) |
| 497 | + if use_data_parallel: |
| 498 | + self.proj = ReplicatedLinear( |
| 499 | + input_size=self.hidden_size, |
| 500 | + output_size=self.hidden_size, |
| 501 | + bias=bias, |
| 502 | + quant_config=quant_config, |
| 503 | + prefix=f"{prefix}.proj", |
| 504 | + ) |
| 505 | + else: |
| 506 | + self.proj = ColumnParallelLinear( |
| 507 | + self.hidden_size, |
| 508 | + self.hidden_size, |
| 509 | + bias=bias, |
| 510 | + gather_output=True, |
| 511 | + quant_config=quant_config, |
| 512 | + prefix=f"{prefix}.proj", |
| 513 | + ) |
488 | 514 | self.post_projection_norm = nn.LayerNorm(self.hidden_size) |
489 | | - self.gate_up_proj = MergedColumnParallelLinear( |
| 515 | + cls_gate_up = (MergedReplicatedLinear |
| 516 | + if use_data_parallel else MergedColumnParallelLinear) |
| 517 | + self.gate_up_proj = cls_gate_up( |
490 | 518 | input_size=self.hidden_size, |
491 | 519 | output_sizes=[context_dim] * 2, |
492 | 520 | bias=bias, |
493 | 521 | quant_config=quant_config, |
494 | 522 | prefix=f"{prefix}.gate_up_proj", |
495 | | - disable_tp=use_data_parallel, |
496 | 523 | ) |
497 | | - self.down_proj = RowParallelLinear( |
| 524 | + cls_down = (ReplicatedLinear |
| 525 | + if use_data_parallel else RowParallelLinear) |
| 526 | + self.down_proj = cls_down( |
498 | 527 | context_dim, |
499 | 528 | self.hidden_size, |
500 | 529 | bias=bias, |
501 | 530 | quant_config=quant_config, |
502 | 531 | prefix=f"{prefix}.down_proj", |
503 | | - disable_tp=use_data_parallel, |
504 | 532 | ) |
505 | 533 | self.act_fn = SiluAndMul() |
506 | 534 | self.extra_activation_func = nn.GELU() |
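
Here the same selection is applied to the merger block; ColumnParallelLinear with gather_output=True already returned the full-width output on every rank, so the replicated substitute is shape-compatible with the LayerNorm that follows. How these layers compose is outside the hunk; a hedged sketch of a plausible flow, mirroring only the attributes constructed above:

    def _merger_forward_sketch(self, x: torch.Tensor) -> torch.Tensor:
        # Sketch under assumptions; not the PR's forward implementation.
        x, _ = self.proj(x)                            # [.., d_model], full width on every rank
        x = self.extra_activation_func(self.post_projection_norm(x))
        gate_up, _ = self.gate_up_proj(x)              # [.., 2 * context_dim]
        x = self.act_fn(gate_up)                       # SiluAndMul -> [.., context_dim]
        x, _ = self.down_proj(x)                       # [.., context_dim] -> [.., d_model]
        return x
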