21 changes: 13 additions & 8 deletions vllm/model_executor/models/glm4_1v.py
@@ -60,6 +60,7 @@
     ColumnParallelLinear,
     MergedColumnParallelLinear,
     QKVParallelLinear,
+    ReplicatedLinear,
     RowParallelLinear,
 )
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -98,7 +99,11 @@
     init_vllm_registered_model,
     maybe_prefix,
 )
-from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model
+from .vision import (
+    conv3d_to_linear_weight,
+    get_vit_attn_backend,
+    run_dp_sharded_mrope_vision_model,
+)

logger = init_logger(__name__)

@@ -478,18 +483,15 @@ def __init__(
         self.hidden_size = hidden_size

         kernel_size = (temporal_patch_size, patch_size, patch_size)
-        self.proj = nn.Conv3d(
-            in_channels,
+        self.proj = ReplicatedLinear(
+            in_channels * math.prod(kernel_size),
             hidden_size,
-            kernel_size=kernel_size,
-            stride=kernel_size,
             bias=True,
+            return_bias=False,
         )

Comment on lines +490 to 492 (Contributor, marked "critical"):

The original code reshaped x before the Conv3d and reshaped the output afterward; with ReplicatedLinear those reshapes are removed. Feeding x into self.proj without them assumes x already matches the linear layer's expected input shape, and a mismatch would cause incorrect computations or runtime errors. Can you confirm that x is correctly preprocessed for ReplicatedLinear?

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        L, C = x.shape
-        x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
-        x = self.proj(x).view(L, self.hidden_size)
+        x = self.proj(x)
         return x
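
The dropped reshapes are safe because the multimodal processor already hands the vision tower pre-flattened patches of shape (L, in_channels * temporal_patch_size * patch_size**2), which is exactly the input size the new ReplicatedLinear is constructed with. Below is a minimal standalone sketch (illustrative sizes, plain torch.nn standing in for vLLM's ReplicatedLinear) checking that the linear path reproduces the old Conv3d path when kernel_size == stride:

```python
import math

import torch
import torch.nn as nn

# Illustrative sizes only (not taken from any particular checkpoint).
num_patches, in_channels = 4, 3
temporal_patch_size, patch_size, hidden_size = 2, 14, 1280
kernel_size = (temporal_patch_size, patch_size, patch_size)

conv = nn.Conv3d(
    in_channels, hidden_size, kernel_size=kernel_size, stride=kernel_size, bias=True
)
linear = nn.Linear(in_channels * math.prod(kernel_size), hidden_size, bias=True)
with torch.no_grad():
    # Same flattening that conv3d_to_linear_weight performs in vision.py.
    linear.weight.copy_(conv.weight.reshape(hidden_size, -1))
    linear.bias.copy_(conv.bias)

# The processor delivers patches already flattened to (L, C).
x = torch.randn(num_patches, in_channels * math.prod(kernel_size))

# Old path: reshape to 5D, convolve over non-overlapping patches, flatten back.
old = conv(x.view(num_patches, in_channels, *kernel_size)).view(num_patches, hidden_size)
# New path: feed the flat patches straight into the linear projection.
new = linear(x)
torch.testing.assert_close(old, new, rtol=1e-4, atol=1e-4)
```

With kernel_size == stride, each output element is a dot product over one disjoint patch, so the two paths are mathematically identical; ReplicatedLinear keeps a full weight copy on every tensor-parallel rank, matching the unsharded Conv3d it replaces.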


@@ -887,6 +889,9 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         loaded_params: set[str] = set()

         for name, loaded_weight in weights:
+            if name.endswith("patch_embed.proj.weight"):
+                loaded_weight = conv3d_to_linear_weight(loaded_weight)
+
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
22 changes: 14 additions & 8 deletions vllm/model_executor/models/qwen2_5_vl.py
@@ -26,6 +26,7 @@
 # limitations under the License.
 """Inference-only Qwen2.5-VL model compatible with HuggingFace weights."""

+import math
 from collections.abc import Callable, Iterable, Mapping, Sequence
 from functools import lru_cache, partial
 from typing import Annotated, Any, Literal, TypeAlias
@@ -56,6 +57,7 @@
     ColumnParallelLinear,
     MergedColumnParallelLinear,
     QKVParallelLinear,
+    ReplicatedLinear,
     RowParallelLinear,
 )
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -98,7 +100,11 @@
     init_vllm_registered_model,
     maybe_prefix,
 )
-from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model
+from .vision import (
+    conv3d_to_linear_weight,
+    get_vit_attn_backend,
+    run_dp_sharded_mrope_vision_model,
+)

logger = init_logger(__name__)

@@ -532,18 +538,15 @@ def __init__(
         self.hidden_size = hidden_size

         kernel_size = (temporal_patch_size, patch_size, patch_size)
-        self.proj = nn.Conv3d(
-            in_channels,
+        self.proj = ReplicatedLinear(
+            in_channels * math.prod(kernel_size),
             hidden_size,
-            kernel_size=kernel_size,
-            stride=kernel_size,
             bias=False,
+            return_bias=False,
         )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        L, C = x.shape
-        x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
-        x = self.proj(x).view(L, self.hidden_size)
+        x = self.proj(x)
         return x

@@ -950,6 +953,9 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         loaded_params: set[str] = set()

         for name, loaded_weight in weights:
+            if name.endswith("patch_embed.proj.weight"):
+                loaded_weight = conv3d_to_linear_weight(loaded_weight)
+
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
27 changes: 18 additions & 9 deletions vllm/model_executor/models/qwen2_vl.py
@@ -25,6 +25,7 @@
 # limitations under the License.
 """Inference-only Qwen2-VL model compatible with HuggingFace weights."""

+import math
 from collections.abc import Callable, Iterable, Mapping, Sequence
 from functools import partial
 from typing import Annotated, Any, Literal, TypeAlias
@@ -53,7 +54,11 @@
 from vllm.distributed import utils as dist_utils
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import QuickGELU
-from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
+from vllm.model_executor.layers.linear import (
+    ColumnParallelLinear,
+    ReplicatedLinear,
+    RowParallelLinear,
+)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding.common import (
     dispatch_rotary_emb_function,
@@ -100,7 +105,11 @@
     init_vllm_registered_model,
     maybe_prefix,
 )
-from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model
+from .vision import (
+    conv3d_to_linear_weight,
+    get_vit_attn_backend,
+    run_dp_sharded_mrope_vision_model,
+)

logger = init_logger(__name__)

@@ -561,18 +570,15 @@ def __init__(
         self.embed_dim = embed_dim

         kernel_size = (temporal_patch_size, patch_size, patch_size)
-        self.proj = nn.Conv3d(
-            in_channels,
+        self.proj = ReplicatedLinear(
+            in_channels * math.prod(kernel_size),
             embed_dim,
-            kernel_size=kernel_size,
-            stride=kernel_size,
             bias=False,
+            return_bias=False,
         )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        L, C = x.shape
-        x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
-        x = self.proj(x).view(L, self.embed_dim)
+        x = self.proj(x)
         return x

@@ -835,6 +841,9 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         loaded_params: set[str] = set()

         for name, loaded_weight in weights:
+            if name.endswith("patch_embed.proj.weight"):
+                loaded_weight = conv3d_to_linear_weight(loaded_weight)
+
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
26 changes: 18 additions & 8 deletions vllm/model_executor/models/qwen3_omni_moe_thinker.py
@@ -22,6 +22,7 @@
 # limitations under the License.
 """Inference-only Qwen3-Omni-Moe model (thinker part)."""

+import math
 from collections.abc import Callable, Iterable, Mapping, Sequence
 from functools import partial
 from typing import Any
@@ -53,7 +54,11 @@
 from vllm.distributed import get_pp_group
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
-from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
+from vllm.model_executor.layers.linear import (
+    ColumnParallelLinear,
+    ReplicatedLinear,
+    RowParallelLinear,
+)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
@@ -98,7 +103,11 @@
     _merge_multimodal_embeddings,
     maybe_prefix,
 )
-from .vision import get_llm_pos_ids_for_vision, get_vit_attn_backend
+from .vision import (
+    conv3d_to_linear_weight,
+    get_llm_pos_ids_for_vision,
+    get_vit_attn_backend,
+)

try:
import flash_attn
@@ -131,18 +140,16 @@ def __init__(
         self.hidden_size = hidden_size

         kernel_size = (temporal_patch_size, patch_size, patch_size)
-        self.proj = nn.Conv3d(
-            in_channels,
+        self.proj = ReplicatedLinear(
+            in_channels * math.prod(kernel_size),
             hidden_size,
-            kernel_size=kernel_size,
-            stride=kernel_size,
             bias=True,
+            return_bias=False,
         )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        L, C = x.shape
-        x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
-        x = self.proj(x).view(L, self.hidden_size)
+        x = self.proj(x)
         return x


@@ -559,6 +566,9 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         loaded_params: set[str] = set()

         for name, loaded_weight in weights:
+            if name.endswith("patch_embed.proj.weight"):
+                loaded_weight = conv3d_to_linear_weight(loaded_weight)
+
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
27 changes: 18 additions & 9 deletions vllm/model_executor/models/qwen3_vl.py
@@ -24,6 +24,7 @@
 # limitations under the License.
 """Inference-only Qwen3VL model compatible with HuggingFace weights."""

+import math
 from collections.abc import Callable, Iterable, Mapping, Sequence
 from functools import partial
 from itertools import islice
@@ -56,7 +57,11 @@
 from vllm.distributed import get_pp_group
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
-from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
+from vllm.model_executor.layers.linear import (
+    ColumnParallelLinear,
+    ReplicatedLinear,
+    RowParallelLinear,
+)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
@@ -107,7 +112,11 @@
     _merge_multimodal_embeddings,
     maybe_prefix,
 )
-from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model
+from .vision import (
+    conv3d_to_linear_weight,
+    get_vit_attn_backend,
+    run_dp_sharded_mrope_vision_model,
+)

logger = init_logger(__name__)

@@ -129,18 +138,15 @@ def __init__(
         self.hidden_size = hidden_size

         kernel_size = (temporal_patch_size, patch_size, patch_size)
-        self.proj = nn.Conv3d(
-            in_channels,
+        self.proj = ReplicatedLinear(
+            in_channels * math.prod(kernel_size),
             hidden_size,
-            kernel_size=kernel_size,
-            stride=kernel_size,
             bias=True,
+            return_bias=False,
         )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        L, C = x.shape
-        x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
-        x = self.proj(x).view(L, self.hidden_size)
+        x = self.proj(x)
         return x

@@ -576,6 +582,9 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         loaded_params: set[str] = set()

         for name, loaded_weight in weights:
+            if name.endswith("patch_embed.proj.weight"):
+                loaded_weight = conv3d_to_linear_weight(loaded_weight)
+
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
16 changes: 16 additions & 0 deletions vllm/model_executor/models/vision.py
@@ -544,3 +544,19 @@ def get_llm_pos_ids_for_vision(
         llm_pos_ids_list.append(_llm_pos_ids + start_idx)
     llm_pos_ids = torch.cat(llm_pos_ids_list, dim=1)
     return llm_pos_ids


+# Due to a performance regression with Conv3D in PyTorch 2.9, we reshape
+# Conv3D weights to Linear weights for better performance.
+# See: https://github.com/vllm-project/vllm/issues/27406
+# and https://github.com/pytorch/pytorch/issues/166122
+# FIXME(Isotr0py): Revert the PR that introduced this workaround
+# (https://github.com/vllm-project/vllm/pull/27418),
+# once the performance issue is resolved in PyTorch.
+def conv3d_to_linear_weight(conv3d_weight: torch.Tensor) -> torch.Tensor:

Comment from a Member on this line: comment why we need this?

"""
Reshape Conv3D weight to Linear weight. Only work when kernel_size==stride.
"""
out_channels, in_channels, kt, kh, kw = conv3d_weight.shape
linear_weight = conv3d_weight.reshape(out_channels, in_channels * kt * kh * kw)
return linear_weight
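
A quick shape-level sanity check of the helper (dimensions are illustrative, not taken from any specific model):

```python
import torch

# (out_channels, in_channels, kT, kH, kW), as stored in HF Conv3d checkpoints.
w_conv = torch.randn(1280, 3, 2, 14, 14)
w_lin = conv3d_to_linear_weight(w_conv)
assert w_lin.shape == (1280, 3 * 2 * 14 * 14)
```

Each model's load_weights applies this to checkpoint tensors whose name ends in patch_embed.proj.weight, so existing HuggingFace checkpoints load without modification.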