@@ -25,6 +25,7 @@
 # limitations under the License.
 """Inference-only Qwen2-VL model compatible with HuggingFace weights."""
 
+import math
 from collections.abc import Callable, Iterable, Mapping, Sequence
 from functools import partial
 from typing import Annotated, Any, Literal, TypeAlias
@@ -53,7 +54,11 @@
 from vllm.distributed import utils as dist_utils
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import QuickGELU
-from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
+from vllm.model_executor.layers.linear import (
+    ColumnParallelLinear,
+    ReplicatedLinear,
+    RowParallelLinear,
+)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding.common import (
     dispatch_rotary_emb_function,
@@ -100,7 +105,11 @@
     init_vllm_registered_model,
     maybe_prefix,
 )
-from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model
+from .vision import (
+    conv3d_to_linear_weight,
+    get_vit_attn_backend,
+    run_dp_sharded_mrope_vision_model,
+)
 
 logger = init_logger(__name__)
 
@@ -561,18 +570,15 @@ def __init__(
         self.embed_dim = embed_dim
 
         kernel_size = (temporal_patch_size, patch_size, patch_size)
-        self.proj = nn.Conv3d(
-            in_channels,
+        self.proj = ReplicatedLinear(
+            in_channels * math.prod(kernel_size),
             embed_dim,
-            kernel_size=kernel_size,
-            stride=kernel_size,
             bias=False,
+            return_bias=False,
         )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        L, C = x.shape
-        x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
-        x = self.proj(x)
+        x = self.proj(x)
         return x
 
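Why this is safe: a `Conv3d` whose stride equals its kernel size computes one independent matrix multiply per non-overlapping patch, and the pixel values reaching this module are already flattened to shape `(L, C * temporal_patch_size * patch_size * patch_size)`, so the convolution collapses to a single GEMM over the last dimension (which also lets `forward` drop its reshapes). Below is a minimal self-contained sketch of that equivalence, using hypothetical Qwen2-VL-like dimensions and plain `nn.Linear` in place of vLLM's `ReplicatedLinear`:

```python
import math

import torch
import torch.nn as nn

# Hypothetical dimensions for illustration only (not taken from the config).
in_channels, embed_dim = 3, 1280
temporal_patch_size, patch_size = 2, 14
kernel_size = (temporal_patch_size, patch_size, patch_size)

conv = nn.Conv3d(
    in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False
)
linear = nn.Linear(in_channels * math.prod(kernel_size), embed_dim, bias=False)

# Flattening the (embed_dim, C, T, P, P) conv kernel row-wise yields the
# equivalent (embed_dim, C*T*P*P) linear weight.
with torch.no_grad():
    linear.weight.copy_(conv.weight.view(embed_dim, -1))

x = torch.randn(16, in_channels * math.prod(kernel_size))  # (L, C*T*P*P)
y_conv = conv(
    x.view(-1, in_channels, temporal_patch_size, patch_size, patch_size)
).view(-1, embed_dim)
y_linear = linear(x)
assert torch.allclose(y_conv, y_linear, atol=1e-4)
```

The `return_bias=False` argument makes `ReplicatedLinear` return just the output tensor instead of an `(output, bias)` tuple, which is what keeps the new `forward` a one-liner.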
@@ -835,6 +841,9 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         loaded_params: set[str] = set()
 
         for name, loaded_weight in weights:
+            if name.endswith("patch_embed.proj.weight"):
+                loaded_weight = conv3d_to_linear_weight(loaded_weight)
+
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
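On the loading side, HuggingFace checkpoints still store `patch_embed.proj.weight` as a 5-D `Conv3d` kernel, so it is flattened into a 2-D matrix before being handed to the regular linear-weight loader. The real helper is imported from `.vision`; the following one-liner is an assumption about its behavior based on the shapes involved, not a copy of the actual implementation:

```python
import torch

def conv3d_to_linear_weight(conv_weight: torch.Tensor) -> torch.Tensor:
    # Assumed behavior: flatten a (out_ch, in_ch, kT, kH, kW) Conv3d kernel
    # into the (out_ch, in_ch * kT * kH * kW) matrix a Linear layer expects.
    return conv_weight.view(conv_weight.size(0), -1)
```

Applying the conversion in `load_weights` keeps existing checkpoints loading unchanged while the patch-embedding forward pass becomes a single matmul.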