Skip to content

Commit 4c8dd12

Browse files
authored
[Misc] Add qwen2.5-vl BNB support (vllm-project#12944)
1 parent 256a2d2 commit 4c8dd12

File tree

1 file changed

+29
-30
lines changed

1 file changed

+29
-30
lines changed

vllm/model_executor/models/qwen2_5_vl.py

+29-30
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040

4141
from vllm.attention import AttentionMetadata
4242
from vllm.config import VllmConfig
43-
from vllm.distributed import parallel_state
43+
from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
4444
from vllm.distributed import utils as dist_utils
4545
from vllm.logger import init_logger
4646
from vllm.model_executor import SamplingMetadata
@@ -207,11 +207,12 @@ def __init__(
207207
) -> None:
208208
super().__init__()
209209
# Per attention head and per partition values.
210-
world_size = parallel_state.get_tensor_model_parallel_world_size()
210+
self.tp_size = parallel_state.get_tensor_model_parallel_world_size()
211+
self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
211212
self.hidden_size_per_attention_head = dist_utils.divide(
212213
projection_size, num_heads)
213214
self.num_attention_heads_per_partition = dist_utils.divide(
214-
num_heads, world_size)
215+
num_heads, self.tp_size)
215216

216217
self.qkv = ColumnParallelLinear(input_size=embed_dim,
217218
output_size=3 * projection_size,
@@ -231,6 +232,29 @@ def __init__(
231232
f"Qwen2.5-VL does not support {self.attn_backend} backend now."
232233
)
233234

235+
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
236+
# [s, b, 3 * head * head_dim]
237+
seq_len, bs, _ = qkv.shape
238+
if self.tp_size > 1:
239+
qkv = tensor_model_parallel_all_gather(qkv)
240+
241+
# [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
242+
q, k, v = qkv.chunk(3, dim=2)
243+
244+
# 3 * [s, b, head * head_dim]
245+
if self.tp_size > 1:
246+
splitter = partial(dist_utils.split_tensor_along_last_dim,
247+
num_partitions=self.tp_size)
248+
q = splitter(q)[self.tp_rank]
249+
k = splitter(k)[self.tp_rank]
250+
v = splitter(v)[self.tp_rank]
251+
252+
# 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
253+
new_shape = (seq_len, bs, self.num_attention_heads_per_partition,
254+
self.hidden_size_per_attention_head)
255+
q, k, v = (x.view(*new_shape) for x in (q, k, v))
256+
return q, k, v
257+
234258
def forward(
235259
self,
236260
x: torch.Tensor,
@@ -240,15 +264,8 @@ def forward(
240264
# [s, b, c] --> [s, b, head * 3 * head_dim]
241265
x, _ = self.qkv(x)
242266

243-
# [s, b, head * 3 * head_dim] --> [s, b, head, 3 * head_dim]
244-
new_x_shape = x.size()[:-1] + (
245-
self.num_attention_heads_per_partition,
246-
3 * self.hidden_size_per_attention_head,
247-
)
248-
x = x.view(*new_x_shape)
249-
250-
# [s, b, head, 3 * head_dim] --> 3 [s, b, head, head_dim]
251-
q, k, v = dist_utils.split_tensor_along_last_dim(x, 3)
267+
# [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
268+
q, k, v = self.split_qkv(x)
252269
batch_size = q.shape[1]
253270

254271
q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous()
@@ -665,24 +682,6 @@ def load_weights(self, weights: Iterable[Tuple[str,
665682
weight_loader(param, loaded_weight, shard_id)
666683
break
667684
else:
668-
if name.endswith("qkv.weight"):
669-
visual_num_heads = self.num_heads
670-
visual_embed_dim = self.hidden_size
671-
head_size = visual_embed_dim // visual_num_heads
672-
loaded_weight = loaded_weight.view(3, visual_num_heads,
673-
head_size,
674-
visual_embed_dim)
675-
loaded_weight = loaded_weight.transpose(0, 1)
676-
loaded_weight = loaded_weight.reshape(-1, visual_embed_dim)
677-
elif name.endswith("qkv.bias"):
678-
visual_num_heads = self.num_heads
679-
visual_embed_dim = self.hidden_size
680-
head_size = visual_embed_dim // visual_num_heads
681-
loaded_weight = loaded_weight.view(3, visual_num_heads,
682-
head_size)
683-
loaded_weight = loaded_weight.transpose(0, 1)
684-
loaded_weight = loaded_weight.reshape(-1)
685-
686685
param = params_dict[name]
687686
weight_loader = getattr(param, "weight_loader",
688687
default_weight_loader)

0 commit comments

Comments
 (0)