from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
-from vllm.distributed import parallel_state
+from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
from vllm.distributed import utils as dist_utils
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
@@ -207,11 +207,12 @@ def __init__(
    ) -> None:
        super().__init__()
        # Per attention head and per partition values.
-        world_size = parallel_state.get_tensor_model_parallel_world_size()
+        self.tp_size = parallel_state.get_tensor_model_parallel_world_size()
+        self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
        self.hidden_size_per_attention_head = dist_utils.divide(
            projection_size, num_heads)
        self.num_attention_heads_per_partition = dist_utils.divide(
-            num_heads, world_size)
+            num_heads, self.tp_size)

        self.qkv = ColumnParallelLinear(input_size=embed_dim,
                                        output_size=3 * projection_size,
@@ -231,6 +232,29 @@ def __init__(
                f"Qwen2.5-VL does not support {self.attn_backend} backend now."
            )

+    def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
+        # [s, b, 3 * head * head_dim]
+        seq_len, bs, _ = qkv.shape
+        if self.tp_size > 1:
+            qkv = tensor_model_parallel_all_gather(qkv)
+
+        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
+        q, k, v = qkv.chunk(3, dim=2)
+
+        # 3 * [s, b, head * head_dim]
+        if self.tp_size > 1:
+            splitter = partial(dist_utils.split_tensor_along_last_dim,
+                               num_partitions=self.tp_size)
+            q = splitter(q)[self.tp_rank]
+            k = splitter(k)[self.tp_rank]
+            v = splitter(v)[self.tp_rank]
+
+        # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
+        new_shape = (seq_len, bs, self.num_attention_heads_per_partition,
+                     self.hidden_size_per_attention_head)
+        q, k, v = (x.view(*new_shape) for x in (q, k, v))
+        return q, k, v
+
    def forward(
        self,
        x: torch.Tensor,
@@ -240,15 +264,8 @@ def forward(
        # [s, b, c] --> [s, b, head * 3 * head_dim]
        x, _ = self.qkv(x)

-        # [s, b, head * 3 * head_dim] --> [s, b, head, 3 * head_dim]
-        new_x_shape = x.size()[:-1] + (
-            self.num_attention_heads_per_partition,
-            3 * self.hidden_size_per_attention_head,
-        )
-        x = x.view(*new_x_shape)
-
-        # [s, b, head, 3 * head_dim] --> 3 [s, b, head, head_dim]
-        q, k, v = dist_utils.split_tensor_along_last_dim(x, 3)
+        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
+        q, k, v = self.split_qkv(x)
        batch_size = q.shape[1]

        q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous()
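For readers following the shape flow, here is a minimal CPU-only sketch of the gather-then-split logic that `split_qkv` implements, with the TP all-gather simulated by a plain `torch.cat` over per-rank shards; the tensor sizes and `tp_size = 2` below are illustrative assumptions, not values from this PR.

```python
# Minimal sketch (not from the PR): the qkv projection's output dim is laid out
# as [Q | K | V]; ColumnParallelLinear shards that dim contiguously across TP
# ranks, so an all-gather restores the full [Q | K | V] order before the
# per-rank head split.
import torch

seq_len, bs, num_heads, head_dim, tp_size = 4, 1, 8, 16, 2  # toy sizes
hidden = num_heads * head_dim

# Full (unsharded) activation: [s, b, 3 * head * head_dim].
full_qkv = torch.randn(seq_len, bs, 3 * hidden)
# Contiguous column-parallel shards, one per rank (what each rank's qkv() emits).
shards = full_qkv.chunk(tp_size, dim=-1)

for tp_rank in range(tp_size):
    # Stand-in for tensor_model_parallel_all_gather along the last dim.
    qkv = torch.cat(shards, dim=-1)

    # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
    q, k, v = qkv.chunk(3, dim=2)

    # Each rank keeps only its own slice of heads.
    q, k, v = (t.chunk(tp_size, dim=-1)[tp_rank] for t in (q, k, v))

    # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
    q, k, v = (t.view(seq_len, bs, num_heads // tp_size, head_dim)
               for t in (q, k, v))
    assert q.shape == (seq_len, bs, num_heads // tp_size, head_dim)
```

In the real code, `dist_utils.split_tensor_along_last_dim` plays the role of the `chunk(tp_size, dim=-1)` call above.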
@@ -665,24 +682,6 @@ def load_weights(self, weights: Iterable[Tuple[str,
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
-                if name.endswith("qkv.weight"):
-                    visual_num_heads = self.num_heads
-                    visual_embed_dim = self.hidden_size
-                    head_size = visual_embed_dim // visual_num_heads
-                    loaded_weight = loaded_weight.view(3, visual_num_heads,
-                                                       head_size,
-                                                       visual_embed_dim)
-                    loaded_weight = loaded_weight.transpose(0, 1)
-                    loaded_weight = loaded_weight.reshape(-1, visual_embed_dim)
-                elif name.endswith("qkv.bias"):
-                    visual_num_heads = self.num_heads
-                    visual_embed_dim = self.hidden_size
-                    head_size = visual_embed_dim // visual_num_heads
-                    loaded_weight = loaded_weight.view(3, visual_num_heads,
-                                                       head_size)
-                    loaded_weight = loaded_weight.transpose(0, 1)
-                    loaded_weight = loaded_weight.reshape(-1)
-
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
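The branch removed above existed to permute the checkpoint's [Q | K | V] row layout into per-head (q, k, v) groups, so that a contiguous column-parallel shard of the output dim contained whole heads. With `split_qkv` doing the all-gather and per-rank split at runtime, the weight can be loaded in its native layout. A rough sketch of the permutation the deleted code performed, using toy sizes for illustration only:

```python
import torch

num_heads, head_dim, embed_dim = 4, 2, 8  # toy sizes, not model values
hidden = num_heads * head_dim

# Checkpoint qkv.weight rows come as [Q | K | V], each block `hidden` rows.
qkv_weight = torch.arange(3 * hidden * embed_dim,
                          dtype=torch.float32).view(3 * hidden, embed_dim)

# What the removed load_weights branch did: regroup rows per head so that a
# contiguous output-dim shard holds complete (q, k, v) triples for its heads.
interleaved = (qkv_weight.view(3, num_heads, head_dim, embed_dim)
               .transpose(0, 1)
               .reshape(-1, embed_dim))
assert interleaved.shape == qkv_weight.shape
```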