-
-
Notifications
You must be signed in to change notification settings - Fork 11.1k
[MM][Bugfix] Replace PatchEmbed's conv3d to linear layer
#27418
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -60,6 +60,7 @@ | |
| ColumnParallelLinear, | ||
| MergedColumnParallelLinear, | ||
| QKVParallelLinear, | ||
| ReplicatedLinear, | ||
| RowParallelLinear, | ||
| ) | ||
| from vllm.model_executor.layers.quantization import QuantizationConfig | ||
|
|
@@ -98,7 +99,11 @@ | |
| init_vllm_registered_model, | ||
| maybe_prefix, | ||
| ) | ||
| from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model | ||
| from .vision import ( | ||
| conv3d_to_linear_weight, | ||
| get_vit_attn_backend, | ||
| run_dp_sharded_mrope_vision_model, | ||
| ) | ||
|
|
||
| logger = init_logger(__name__) | ||
|
|
||
|
|
@@ -478,18 +483,15 @@ def __init__( | |
| self.hidden_size = hidden_size | ||
|
|
||
| kernel_size = (temporal_patch_size, patch_size, patch_size) | ||
| self.proj = nn.Conv3d( | ||
| in_channels, | ||
| self.proj = ReplicatedLinear( | ||
| in_channels * math.prod(kernel_size), | ||
| hidden_size, | ||
| kernel_size=kernel_size, | ||
| stride=kernel_size, | ||
| bias=True, | ||
| return_bias=False, | ||
| ) | ||
|
|
||
| def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
| L, C = x.shape | ||
|
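There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The original code reshapes the input tensor `x` before applying the convolutional layer and then reshapes the output. Can you confirm that the input tensor `x` is correctly preprocessed to match the expected input shape of the `ReplicatedLinear` layer? |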
||
| x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size) | ||
| x = self.proj(x).view(L, self.hidden_size) | ||
| x = self.proj(x) | ||
| return x | ||
|
|
||
|
|
||
|
|
@@ -887,6 +889,9 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: | |
| loaded_params: set[str] = set() | ||
|
|
||
| for name, loaded_weight in weights: | ||
| if name.endswith("patch_embed.proj.weight"): | ||
| loaded_weight = conv3d_to_linear_weight(loaded_weight) | ||
|
|
||
| for param_name, weight_name, shard_id in stacked_params_mapping: | ||
| if weight_name not in name: | ||
| continue | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -26,6 +26,7 @@ | |
| # limitations under the License. | ||
| """Inference-only Qwen2.5-VL model compatible with HuggingFace weights.""" | ||
|
|
||
| import math | ||
| from collections.abc import Callable, Iterable, Mapping, Sequence | ||
| from functools import lru_cache, partial | ||
| from typing import Annotated, Any, Literal, TypeAlias | ||
|
|
@@ -56,6 +57,7 @@ | |
| ColumnParallelLinear, | ||
| MergedColumnParallelLinear, | ||
| QKVParallelLinear, | ||
| ReplicatedLinear, | ||
| RowParallelLinear, | ||
| ) | ||
| from vllm.model_executor.layers.quantization import QuantizationConfig | ||
|
|
@@ -98,7 +100,11 @@ | |
| init_vllm_registered_model, | ||
| maybe_prefix, | ||
| ) | ||
| from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model | ||
| from .vision import ( | ||
| conv3d_to_linear_weight, | ||
| get_vit_attn_backend, | ||
| run_dp_sharded_mrope_vision_model, | ||
| ) | ||
|
|
||
| logger = init_logger(__name__) | ||
|
|
||
|
|
@@ -532,18 +538,15 @@ def __init__( | |
| self.hidden_size = hidden_size | ||
|
|
||
| kernel_size = (temporal_patch_size, patch_size, patch_size) | ||
| self.proj = nn.Conv3d( | ||
| in_channels, | ||
| self.proj = ReplicatedLinear( | ||
| in_channels * math.prod(kernel_size), | ||
| hidden_size, | ||
| kernel_size=kernel_size, | ||
| stride=kernel_size, | ||
| bias=False, | ||
| return_bias=False, | ||
| ) | ||
|
|
||
| def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
| L, C = x.shape | ||
| x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size) | ||
| x = self.proj(x).view(L, self.hidden_size) | ||
| x = self.proj(x) | ||
| return x | ||
|
Comment on lines
548
to
550
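There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The original code reshapes the input tensor `x` before applying the convolutional layer and then reshapes the output. Can you confirm that the input tensor `x` is correctly preprocessed to match the expected input shape of the `ReplicatedLinear` layer? |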
||
|
|
||
|
|
||
|
|
@@ -950,6 +953,9 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: | |
| loaded_params: set[str] = set() | ||
|
|
||
| for name, loaded_weight in weights: | ||
| if name.endswith("patch_embed.proj.weight"): | ||
| loaded_weight = conv3d_to_linear_weight(loaded_weight) | ||
|
|
||
| for param_name, weight_name, shard_id in stacked_params_mapping: | ||
| if weight_name not in name: | ||
| continue | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -25,6 +25,7 @@ | |
| # limitations under the License. | ||
| """Inference-only Qwen2-VL model compatible with HuggingFace weights.""" | ||
|
|
||
| import math | ||
| from collections.abc import Callable, Iterable, Mapping, Sequence | ||
| from functools import partial | ||
| from typing import Annotated, Any, Literal, TypeAlias | ||
|
|
@@ -53,7 +54,11 @@ | |
| from vllm.distributed import utils as dist_utils | ||
| from vllm.logger import init_logger | ||
| from vllm.model_executor.layers.activation import QuickGELU | ||
| from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear | ||
| from vllm.model_executor.layers.linear import ( | ||
| ColumnParallelLinear, | ||
| ReplicatedLinear, | ||
| RowParallelLinear, | ||
| ) | ||
| from vllm.model_executor.layers.quantization import QuantizationConfig | ||
| from vllm.model_executor.layers.rotary_embedding.common import ( | ||
| dispatch_rotary_emb_function, | ||
|
|
@@ -100,7 +105,11 @@ | |
| init_vllm_registered_model, | ||
| maybe_prefix, | ||
| ) | ||
| from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model | ||
| from .vision import ( | ||
| conv3d_to_linear_weight, | ||
| get_vit_attn_backend, | ||
| run_dp_sharded_mrope_vision_model, | ||
| ) | ||
|
|
||
| logger = init_logger(__name__) | ||
|
|
||
|
|
@@ -561,18 +570,15 @@ def __init__( | |
| self.embed_dim = embed_dim | ||
|
|
||
| kernel_size = (temporal_patch_size, patch_size, patch_size) | ||
| self.proj = nn.Conv3d( | ||
| in_channels, | ||
| self.proj = ReplicatedLinear( | ||
| in_channels * math.prod(kernel_size), | ||
| embed_dim, | ||
| kernel_size=kernel_size, | ||
| stride=kernel_size, | ||
| bias=False, | ||
| return_bias=False, | ||
| ) | ||
|
|
||
| def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
| L, C = x.shape | ||
| x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size) | ||
| x = self.proj(x).view(L, self.embed_dim) | ||
| x = self.proj(x) | ||
| return x | ||
|
Comment on lines
580
to
582
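There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The original code reshapes the input tensor `x` before applying the convolutional layer and then reshapes the output. Can you confirm that the input tensor `x` is correctly preprocessed to match the expected input shape of the `ReplicatedLinear` layer? |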
||
|
|
||
|
|
||
|
|
@@ -835,6 +841,9 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: | |
| loaded_params: set[str] = set() | ||
|
|
||
| for name, loaded_weight in weights: | ||
| if name.endswith("patch_embed.proj.weight"): | ||
| loaded_weight = conv3d_to_linear_weight(loaded_weight) | ||
|
|
||
| for param_name, weight_name, shard_id in stacked_params_mapping: | ||
| if weight_name not in name: | ||
| continue | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -22,6 +22,7 @@ | |
| # limitations under the License. | ||
| """Inference-only Qwen3-Omni-Moe model (thinker part).""" | ||
|
|
||
| import math | ||
| from collections.abc import Callable, Iterable, Mapping, Sequence | ||
| from functools import partial | ||
| from typing import Any | ||
|
|
@@ -53,7 +54,11 @@ | |
| from vllm.distributed import get_pp_group | ||
| from vllm.logger import init_logger | ||
| from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY | ||
| from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear | ||
| from vllm.model_executor.layers.linear import ( | ||
| ColumnParallelLinear, | ||
| ReplicatedLinear, | ||
| RowParallelLinear, | ||
| ) | ||
| from vllm.model_executor.layers.logits_processor import LogitsProcessor | ||
| from vllm.model_executor.layers.quantization import QuantizationConfig | ||
| from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead | ||
|
|
@@ -98,7 +103,11 @@ | |
| _merge_multimodal_embeddings, | ||
| maybe_prefix, | ||
| ) | ||
| from .vision import get_llm_pos_ids_for_vision, get_vit_attn_backend | ||
| from .vision import ( | ||
| conv3d_to_linear_weight, | ||
| get_llm_pos_ids_for_vision, | ||
| get_vit_attn_backend, | ||
| ) | ||
|
|
||
| try: | ||
| import flash_attn | ||
|
|
@@ -131,18 +140,16 @@ def __init__( | |
| self.hidden_size = hidden_size | ||
|
|
||
| kernel_size = (temporal_patch_size, patch_size, patch_size) | ||
| self.proj = nn.Conv3d( | ||
| in_channels, | ||
| self.proj = ReplicatedLinear( | ||
| in_channels * math.prod(kernel_size), | ||
| hidden_size, | ||
| kernel_size=kernel_size, | ||
| stride=kernel_size, | ||
| bias=True, | ||
| return_bias=False, | ||
| ) | ||
|
|
||
| def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
| L, C = x.shape | ||
| x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size) | ||
| x = self.proj(x).view(L, self.hidden_size) | ||
| x = self.proj(x) | ||
|
Comment on lines
142
to
+152
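There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The original code reshapes the input tensor `x` before applying the convolutional layer and then reshapes the output. Can you confirm that the input tensor `x` is correctly preprocessed to match the expected input shape of the `ReplicatedLinear` layer? |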
||
| return x | ||
|
|
||
|
|
||
|
|
@@ -559,6 +566,9 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: | |
| loaded_params: set[str] = set() | ||
|
|
||
| for name, loaded_weight in weights: | ||
| if name.endswith("patch_embed.proj.weight"): | ||
| loaded_weight = conv3d_to_linear_weight(loaded_weight) | ||
|
|
||
| for param_name, weight_name, shard_id in stacked_params_mapping: | ||
| if weight_name not in name: | ||
| continue | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -24,6 +24,7 @@ | |
| # limitations under the License. | ||
| """Inference-only Qwen3VL model compatible with HuggingFace weights.""" | ||
|
|
||
| import math | ||
| from collections.abc import Callable, Iterable, Mapping, Sequence | ||
| from functools import partial | ||
| from itertools import islice | ||
|
|
@@ -56,7 +57,11 @@ | |
| from vllm.distributed import get_pp_group | ||
| from vllm.logger import init_logger | ||
| from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY | ||
| from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear | ||
| from vllm.model_executor.layers.linear import ( | ||
| ColumnParallelLinear, | ||
| ReplicatedLinear, | ||
| RowParallelLinear, | ||
| ) | ||
| from vllm.model_executor.layers.logits_processor import LogitsProcessor | ||
| from vllm.model_executor.layers.quantization import QuantizationConfig | ||
| from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead | ||
|
|
@@ -107,7 +112,11 @@ | |
| _merge_multimodal_embeddings, | ||
| maybe_prefix, | ||
| ) | ||
| from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model | ||
| from .vision import ( | ||
| conv3d_to_linear_weight, | ||
| get_vit_attn_backend, | ||
| run_dp_sharded_mrope_vision_model, | ||
| ) | ||
|
|
||
| logger = init_logger(__name__) | ||
|
|
||
|
|
@@ -129,18 +138,15 @@ def __init__( | |
| self.hidden_size = hidden_size | ||
|
|
||
| kernel_size = (temporal_patch_size, patch_size, patch_size) | ||
| self.proj = nn.Conv3d( | ||
| in_channels, | ||
| self.proj = ReplicatedLinear( | ||
| in_channels * math.prod(kernel_size), | ||
| hidden_size, | ||
| kernel_size=kernel_size, | ||
| stride=kernel_size, | ||
| bias=True, | ||
| return_bias=False, | ||
| ) | ||
|
|
||
| def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
| L, C = x.shape | ||
| x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size) | ||
| x = self.proj(x).view(L, self.hidden_size) | ||
| x = self.proj(x) | ||
| return x | ||
|
Comment on lines
+149
to
150
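There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The original code reshapes the input tensor `x` before applying the convolutional layer and then reshapes the output. Can you confirm that the input tensor `x` is correctly preprocessed to match the expected input shape of the `ReplicatedLinear` layer? |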
||
|
|
||
|
|
||
|
|
@@ -576,6 +582,9 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: | |
| loaded_params: set[str] = set() | ||
|
|
||
| for name, loaded_weight in weights: | ||
| if name.endswith("patch_embed.proj.weight"): | ||
| loaded_weight = conv3d_to_linear_weight(loaded_weight) | ||
|
|
||
| for param_name, weight_name, shard_id in stacked_params_mapping: | ||
| if weight_name not in name: | ||
| continue | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -544,3 +544,19 @@ def get_llm_pos_ids_for_vision( | |
| llm_pos_ids_list.append(_llm_pos_ids + start_idx) | ||
| llm_pos_ids = torch.cat(llm_pos_ids_list, dim=1) | ||
| return llm_pos_ids | ||
|
|
||
|
|
||
| # Due to a performance regression with Conv3D in PyTorch2.9, we reshape | ||
| # Conv3D weights to Linear weights for better performance. | ||
| # See: https://github.com/vllm-project/vllm/issues/27406 | ||
| # and https://github.com/pytorch/pytorch/issues/166122 | ||
| # FIXME(Isotr0py): Revert the PR introduces this workaround | ||
| # (https://github.com/vllm-project/vllm/pull/27418), | ||
| # once the performance issue is resolved in PyTorch. | ||
| def conv3d_to_linear_weight(conv3d_weight: torch.Tensor) -> torch.Tensor: | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Could you add a comment explaining why we need this? |
||
| """ | ||
| Reshape Conv3D weight to Linear weight. Only work when kernel_size==stride. | ||
| """ | ||
| out_channels, in_channels, kt, kh, kw = conv3d_weight.shape | ||
| linear_weight = conv3d_weight.reshape(out_channels, in_channels * kt * kh * kw) | ||
| return linear_weight | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The original code reshapes the input tensor `x` before applying the convolutional layer and then reshapes the output. With the replacement of `nn.Conv3d` by `ReplicatedLinear`, these reshaping operations are no longer necessary and have been removed. However, it's crucial to ensure that the input tensor `x` is now directly compatible with the `ReplicatedLinear` layer's expected input shape. This change might introduce a critical issue if the input shape is not correctly adapted to the linear layer, potentially leading to incorrect computations or errors. The original code's reshaping operations might have been essential for aligning the input with the convolutional layer's expected format. Directly feeding `x` into `self.proj` without proper reshaping could lead to a mismatch in dimensions, causing the linear layer to perform unintended operations or raise exceptions. It's imperative to verify that the input `x` now has the correct shape expected by `ReplicatedLinear` to avoid breaking the model's functionality.

Can you confirm that the input tensor `x` is correctly preprocessed to match the expected input shape of the `ReplicatedLinear` layer? If not, this could lead to a critical error.