
Commit c66c7f8

[Bugfix] Fix PaliGemma MMP (vllm-project#6930)
Parent: 6e063ea

1 file changed (+2 additions, -5 deletions)

vllm/model_executor/models/paligemma.py

Lines changed: 2 additions & 5 deletions
@@ -9,7 +9,6 @@
 from vllm.config import CacheConfig, MultiModalConfig
 from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from vllm.logger import init_logger
-from vllm.model_executor.layers.linear import ColumnParallelLinear
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
@@ -133,12 +132,10 @@ class PaliGemmaMultiModalProjector(nn.Module):
     def __init__(self, vision_hidden_size: int, projection_dim: int):
         super().__init__()
 
-        self.linear = ColumnParallelLinear(vision_hidden_size,
-                                           projection_dim,
-                                           bias=True)
+        self.linear = nn.Linear(vision_hidden_size, projection_dim, bias=True)
 
     def forward(self, image_features: torch.Tensor) -> torch.Tensor:
-        hidden_states, _ = self.linear(image_features)
+        hidden_states = self.linear(image_features)
         return hidden_states
 
 
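For context: vLLM's ColumnParallelLinear shards the output dimension across tensor-parallel ranks and its forward() returns a tuple (output, bias), which is why the old code unpacked `hidden_states, _`. The plain torch nn.Linear used after this change returns a single tensor and leaves the small projector unsharded. Below is a minimal, self-contained sketch of the projector as it looks after the fix; it assumes only PyTorch, and the dimensions in the usage example are illustrative, not taken from this commit.

import torch
import torch.nn as nn


class PaliGemmaMultiModalProjector(nn.Module):
    """Projects vision-tower features into the language model's hidden size."""

    def __init__(self, vision_hidden_size: int, projection_dim: int):
        super().__init__()
        # nn.Linear's forward() returns a single tensor, unlike vLLM's
        # ColumnParallelLinear, whose forward() returns (output, bias).
        self.linear = nn.Linear(vision_hidden_size, projection_dim, bias=True)

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        hidden_states = self.linear(image_features)
        return hidden_states


# Illustrative usage with made-up shapes (batch, num_patches, hidden):
projector = PaliGemmaMultiModalProjector(vision_hidden_size=1152,
                                         projection_dim=2048)
image_features = torch.randn(1, 256, 1152)
print(projector(image_features).shape)  # torch.Size([1, 256, 2048])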
