681 changes: 434 additions & 247 deletions vllm/model_executor/layers/pooler.py

Large diffs are not rendered by default.
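Since the pooler.py diff is collapsed here, the snippet below is only a rough sketch of the refactoring pattern the rest of this PR relies on: pool the hidden states, run the classifier head in fp32, then apply the activation. The class name and signature are illustrative assumptions inferred from the call sites in the files below, not the rendered implementation.

import torch
from torch import nn
from typing import Optional

# Toy stand-in for the refactored ClassifierPooler: pooling, classification and
# activation are injected as separate pieces. The real class also takes
# model_config and PoolingMetadata, which are omitted in this sketch.
class ToyClassifierPooler(nn.Module):
    def __init__(self, classifier: nn.Module, act_fn: Optional[nn.Module] = None):
        super().__init__()
        self.classifier = classifier
        self.act_fn = act_fn if act_fn is not None else nn.Identity()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        pooled = hidden_states[-1]                # stand-in for a LAST-token PoolingMethod
        logits = self.classifier(pooled.float())  # classifier heads run in fp32 in this PR
        return self.act_fn(logits)

pooler = ToyClassifierPooler(nn.Linear(16, 2), act_fn=nn.Softmax(dim=-1))
print(pooler(torch.randn(8, 16)))  # class probabilities for one 8-token sequence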

99 changes: 48 additions & 51 deletions vllm/model_executor/models/adapters.py
@@ -58,22 +58,27 @@ def __init__(
) -> None:
super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)

self.vllm_config = vllm_config

# These are not used in pooling models
for attr in ("lm_head", "logits_processor"):
if hasattr(self, attr):
delattr(self, attr)

# If the model already defines a pooler instance, don't overwrite it
if not getattr(self, "_pooler", None):
self._init_pooler(vllm_config, prefix=prefix)

def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None

# If the model already defines a pooler instance, don't overwrite it
if not getattr(self, "_pooler", None):
self._pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=default_pooling_type,
normalize=default_normalize,
softmax=default_softmax,
)
self._pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=default_pooling_type,
normalize=default_normalize,
softmax=default_softmax,
)

def pooler(
self,
@@ -165,7 +170,9 @@ def as_seq_cls_model(cls: _T) -> _T:

# Lazy import
from vllm.model_executor.layers.linear import RowParallelLinear
from vllm.model_executor.layers.pooler import PoolerOutput, PoolingType
from vllm.model_executor.layers.pooler import (ClassifierPooler,
PoolerOutput, PoolingType,
SimplePooler)
from vllm.model_executor.models.interfaces import SupportsCrossEncoding
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors
@@ -182,30 +189,40 @@ def as_seq_cls_model(cls: _T) -> _T:
class ModelForSequenceClassification(ModelForPooling,
SupportsCrossEncoding):

def __init__(
self,
*,
vllm_config: "VllmConfig",
prefix: str = "",
**kwargs: Any,
) -> None:
super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)

def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config

self.vllm_config = vllm_config
self.task = vllm_config.model_config.task
self.pooling_type = (
vllm_config.model_config.pooler_config.pooling_type)

self.score = RowParallelLinear(config.hidden_size,
config.num_labels,
quant_config=quant_config,
input_is_parallel=False,
bias=False,
prefix=maybe_prefix(
prefix, "score"))
self.score = RowParallelLinear(
config.hidden_size,
config.num_labels,
input_is_parallel=False,
bias=False,
params_dtype=torch.float32,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "score"),
)

pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None

pooler = SimplePooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.LAST,
normalize=False,
softmax=True,
)

self._pooler = ClassifierPooler(
vllm_config.model_config,
pooling=pooler.pooling,
classifier=self._classifier,
act_fn=pooler.head.activation,
)

def _classifier(self, x: torch.Tensor):
x, _ = self.score(x.float())
return x

def forward(
self,
@@ -222,27 +239,7 @@ def pooler(
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:

def get_logits(hidden_states):
if isinstance(hidden_states, list):
logits = [self.score(state)[0] for state in hidden_states]
else:
logits, _ = self.score(hidden_states)
return logits

if self.pooling_type == PoolingType.ALL:
logits = get_logits(hidden_states)
return self._pooler(logits, pooling_metadata)
else:
hidden_states = self._pooler.extract_states(
hidden_states, pooling_metadata)
logits = get_logits(hidden_states)
pooled_data = self._pooler.head(logits, pooling_metadata)

pooled_outputs = [
self._pooler.build_output(data) for data in pooled_data
]
return PoolerOutput(outputs=pooled_outputs)
return self._pooler(hidden_states, pooling_metadata)

def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
tokens = getattr(self.config, "classifier_from_token", None)
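With `_init_pooler` now owning the full pooling/classification pipeline, converting an existing causal LM stays a one-liner. A hypothetical usage sketch (Llama is chosen only because it appears elsewhere in this PR):

from vllm.model_executor.models.adapters import as_seq_cls_model
from vllm.model_executor.models.llama import LlamaForCausalLM

# Wrap an existing causal-LM class; the returned class builds the
# ClassifierPooler pipeline shown above inside its _init_pooler.
LlamaForSequenceClassification = as_seq_cls_model(LlamaForCausalLM)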
25 changes: 16 additions & 9 deletions vllm/model_executor/models/bert.py
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Iterable
from typing import Optional
from typing import Optional, Union

import torch
from torch import nn
@@ -18,7 +18,7 @@
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.pooler import (ClassifierPooler, Pooler,
PoolingType)
PoolingMethod, PoolingType)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
@@ -83,14 +83,18 @@ class BertPooler(nn.Module):

def __init__(self, config: BertConfig):
super().__init__()

self.pooling = PoolingMethod.from_pooling_type(PoolingType.CLS)
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# We "pool" the model by simply taking the hidden state corresponding
# to the first token.
first_token_tensor = hidden_states[0, :]
pooled_output = self.dense(first_token_tensor)
def forward(
self,
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata,
) -> Union[torch.Tensor, list[torch.Tensor]]:
pooled_output = self.pooling(hidden_states, pooling_metadata)
pooled_output = self.dense(pooled_output)
pooled_output = self.activation(pooled_output)
return pooled_output

@@ -466,8 +470,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
embedding_class=BertEmbedding,
add_pooling_layer=True)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
self._pooler = ClassifierPooler(vllm_config.model_config,
self.classifier, self.bert.pooler)
self._pooler = ClassifierPooler(
vllm_config.model_config,
pooling=self.bert.pooler,
classifier=self.classifier,
)

def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):

4 changes: 2 additions & 2 deletions vllm/model_executor/models/gritlm.py
@@ -9,7 +9,7 @@

from vllm.config import ModelConfig, VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.pooler import PoolerHead
from vllm.model_executor.layers.pooler import PoolerHead, PoolerNormalize
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.model_executor.pooling_metadata import (PoolingMetadata,
PoolingTensors)
@@ -49,7 +49,7 @@ def tokens_to_ids(tokens: list[str]) -> array:
self.embed_pattern_ids = tokens_to_ids(
["▁<", "|", "embed", "|", ">", "<0x0A>"])

self.head = PoolerHead(normalize=True, softmax=False)
self.head = PoolerHead(PoolerNormalize())

def _find_array(self, arr: array, target: array, start_idx: int) -> int:
"""
2 changes: 1 addition & 1 deletion vllm/model_executor/models/interfaces.py
@@ -639,7 +639,7 @@ def supports_cross_encoding(
def has_step_pooler(model: Union[type[object], object]) -> bool:
"""Check if the model uses step pooler."""
return is_pooling_model(model) and any(
type(module).__name__ == "StepPool" for module in model.modules())
type(module).__name__ == "StepPooler" for module in model.modules())

Member Author commented:
This looks hacky. I'm planning to require models to define pooler as a BasePooler instance in the next PR so we can directly inspect model.pooler to get this information



class SupportsQuant:
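A minimal sketch of the direction the comment above describes — the `pooler` attribute and the isinstance check are assumptions about the planned follow-up PR, not code in this one; it reuses the helpers from the hunk above:

def has_step_pooler(model: Union[type[object], object]) -> bool:
    """Hypothetical follow-up: inspect model.pooler directly instead of
    scanning module class names for "StepPooler"."""
    return (is_pooling_model(model)
            and isinstance(getattr(model, "pooler", None), StepPooler))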
39 changes: 26 additions & 13 deletions vllm/model_executor/models/jamba.py
@@ -19,7 +19,8 @@
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.pooler import (ClassifierPooler, PoolingType,
SimplePooler)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
@@ -564,29 +565,41 @@ class JambaForSequenceClassification(JambaForCausalLM):

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)

config = vllm_config.model_config.hf_config
num_labels: int = config.num_labels
score_bias: bool = getattr(config, 'score_bias', False)
self.score = nn.Linear(config.hidden_size, num_labels, bias=score_bias)

# TODO: The original reward weights have float32 accuracy data, we
# would like to load them in fp32 to get that extra precision.
# Currently weight_loader passes the weight which is already in bf16
self.score = nn.Linear(
config.hidden_size,
num_labels,
bias=score_bias,
dtype=torch.float32,

Member Author commented:
The model uses Qwen2ForCausalLM architecture so it should not be related

)

pooler_config = vllm_config.model_config.pooler_config
self._pooler = Pooler.from_config_with_defaults(
assert pooler_config is not None

pooler = SimplePooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.LAST,
normalize=False,
softmax=False)
softmax=False,
)

self._pooler = ClassifierPooler(
vllm_config.model_config,
pooling=pooler.pooling,
classifier=self.score,
act_fn=pooler.head.activation,
)

def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
hidden_states = hidden_states.float()
logits = self.score(hidden_states)
return self._pooler(logits, pooling_metadata)

def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
# TODO: The reward weights themselves have float32 accuracy data, we
# would like to load them in fp32 to get that extra precision.
super().load_weights(weights)
self.score = self.score.float()
return self._pooler(hidden_states, pooling_metadata)
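
The fp32 score head above follows the same reasoning as the TODO: bf16 has only about 8 significand bits, so reward-head weights lose their low-order precision unless kept in float32. A tiny illustration (the value is arbitrary):

import torch

w = torch.tensor(1.0001, dtype=torch.float32)
print(w.to(torch.bfloat16).item())  # 1.0 — the 1e-4 offset is below bf16 resolution near 1.0
print(w.item())                     # ~1.0001 is preserved in fp32
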
33 changes: 18 additions & 15 deletions vllm/model_executor/models/modernbert.py
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
from typing import Optional
from typing import Optional, Union

import torch
from torch import nn
@@ -13,7 +13,8 @@
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.linear import (QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.pooler import ClassifierPooler
from vllm.model_executor.layers.pooler import (BasePooler, ClassifierPooler,
PoolingMethod, PoolingType)
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
@@ -252,10 +253,13 @@ def forward(
return norm_outputs


class ModernBertPooler(nn.Module):
class ModernBertPooler(BasePooler):

def __init__(self, config: ModernBertConfig):
super().__init__()

pooling_type = PoolingType[config.classifier_pooling.upper()]
self.pooling = PoolingMethod.from_pooling_type(pooling_type)
self.dense = nn.Linear(config.hidden_size, config.hidden_size,
config.classifier_bias)
self.pooling_type = config.classifier_pooling
@@ -264,15 +268,12 @@ def __init__(self, config: ModernBertConfig):
eps=config.norm_eps,
bias=config.norm_bias)

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
pooled_output = hidden_states
if self.pooling_type == "mean":
pooled_output = pooled_output.mean(dim=0, keepdim=False)
elif self.pooling_type == "cls":
pooled_output = pooled_output[0, :]
else:
raise ValueError("Pooling type should be either `cls` or `mean`, "
f"but got {self.pooling_type}")
def forward(
self,
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata,
) -> Union[torch.Tensor, list[torch.Tensor]]:
pooled_output = self.pooling(hidden_states, pooling_metadata)
pooled_output = self.norm(self.act(self.dense(pooled_output)))
return pooled_output

@@ -287,9 +288,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.model = ModernBertModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "modernbert"))
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
self._pooler = ClassifierPooler(vllm_config.model_config,
self.classifier,
ModernBertPooler(config))
self._pooler = ClassifierPooler(
vllm_config.model_config,
pooling=ModernBertPooler(config),
classifier=self.classifier,
)

def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):

13 changes: 8 additions & 5 deletions vllm/model_executor/models/roberta.py
@@ -10,7 +10,7 @@
from transformers import RobertaConfig

from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import ClassifierPooler
from vllm.model_executor.layers.pooler import ClassifierPooler, CLSPool
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -105,8 +105,8 @@ def __init__(self, config: RobertaConfig):
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

def forward(self, features, **kwargs):
x = features[0, :] # take <s> token (equiv. to [CLS])
def forward(self, x: torch.Tensor) -> torch.Tensor:
# CLSPool has already been applied in `pooling`
x = self.dense(x)
x = torch.tanh(x)
x = self.out_proj(x)
@@ -183,8 +183,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
add_pooling_layer=False)
self.classifier = RobertaClassificationHead(config)

self._pooler = ClassifierPooler(vllm_config.model_config,
self.classifier)
self._pooler = ClassifierPooler(
vllm_config.model_config,
pooling=CLSPool(),
classifier=self.classifier,
)

def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
bert_weights, task_weights = roberta_task_weights_filter(weights)