[Model] Consolidate pooler implementations #20927
Merged
DarkLight1337 merged 16 commits into vllm-project:main from DarkLight1337:consolidate-poolers on Jul 16, 2025
Commits
7d2bd3b [Model] Consolidate pooler implementations (DarkLight1337)
7d10a82 Rework (DarkLight1337)
fcfda73 Fix pre-commit (DarkLight1337)
96c548a Fix initialization (DarkLight1337)
e78ad95 Simplify (DarkLight1337)
d44795d Rename (DarkLight1337)
39754c9 More abstraction (DarkLight1337)
8b5f995 Optimize (DarkLight1337)
db6d4f7 Fix jamba weight loading (DarkLight1337)
9e7d448 Simplify (DarkLight1337)
d510d9c Fix attribute access (DarkLight1337)
f74f21f Fix adapter (DarkLight1337)
a0a9bac Handle different tensor size (DarkLight1337)
eac90db Fix (DarkLight1337)
f144b08 Simplify (DarkLight1337)
64fae14 Fix accuracy issue (DarkLight1337)
@@ -19,7 +19,8 @@
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
-from vllm.model_executor.layers.pooler import Pooler, PoolingType
+from vllm.model_executor.layers.pooler import (ClassifierPooler, PoolingType,
+                                               SimplePooler)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
@@ -564,29 +565,41 @@ class JambaForSequenceClassification(JambaForCausalLM):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__(vllm_config=vllm_config, prefix=prefix)
 
         config = vllm_config.model_config.hf_config
         num_labels: int = config.num_labels
         score_bias: bool = getattr(config, 'score_bias', False)
-        self.score = nn.Linear(config.hidden_size, num_labels, bias=score_bias)
+
+        # TODO: The original reward weights have float32 accuracy data, we
+        # would like to load them in fp32 to get that extra precision.
+        # Currently weight_loader passes the weight which is already in bf16
+        self.score = nn.Linear(
+            config.hidden_size,
+            num_labels,
+            bias=score_bias,
+            dtype=torch.float32,
+        )
+
         pooler_config = vllm_config.model_config.pooler_config
-        self._pooler = Pooler.from_config_with_defaults(
+        assert pooler_config is not None
+
+        pooler = SimplePooler.from_config_with_defaults(
             pooler_config,
             pooling_type=PoolingType.LAST,
             normalize=False,
-            softmax=False)
+            softmax=False,
+        )
+
+        self._pooler = ClassifierPooler(
+            vllm_config.model_config,
+            pooling=pooler.pooling,
+            classifier=self.score,
+            act_fn=pooler.head.activation,
+        )
 
     def pooler(
         self,
         hidden_states: torch.Tensor,
         pooling_metadata: PoolingMetadata,
     ) -> Optional[PoolerOutput]:
-        hidden_states = hidden_states.float()
-        logits = self.score(hidden_states)
-        return self._pooler(logits, pooling_metadata)
-
-    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
-        # TODO: The reward weights themselves have float32 accuracy data, we
-        # would like to load them in fp32 to get that extra precision.
-        super().load_weights(weights)
-        self.score = self.score.float()
+        return self._pooler(hidden_states, pooling_metadata)

Review comment on the `dtype=torch.float32,` line: The model uses ...
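The hunk above appears to change the order of operations: the old `pooler` method scored every token with `self.score` and then pooled the logits, while the new code hands the raw hidden states to a pooler that owns the classifier. The following dependency-free sketch illustrates why the two orders agree for LAST pooling. It is not vLLM code: `last_token_pool`, `linear`, `classify_then_pool`, and `pool_then_classify` are hypothetical helpers, plain Python lists stand in for tensors, and the assumption that `ClassifierPooler` pools before classifying is inferred from the diff rather than confirmed by it.

```python
from typing import Callable, Sequence

Vector = list[float]


def last_token_pool(rows: Sequence[Vector]) -> Vector:
    """LAST pooling: keep only the final token's row."""
    return list(rows[-1])


def linear(weight: Sequence[Vector], bias: Vector, x: Vector) -> Vector:
    """Matrix-vector product plus bias, standing in for a linear score head."""
    return [sum(w * v for w, v in zip(row, x)) + b
            for row, b in zip(weight, bias)]


def classify_then_pool(hidden_states: Sequence[Vector],
                       weight: Sequence[Vector], bias: Vector) -> Vector:
    """Old flow (assumed): score every token, then pool the per-token logits."""
    logits = [linear(weight, bias, h) for h in hidden_states]
    return last_token_pool(logits)


def pool_then_classify(hidden_states: Sequence[Vector],
                       weight: Sequence[Vector], bias: Vector,
                       act_fn: Callable[[Vector], Vector] = lambda v: v) -> Vector:
    """New flow (assumed): pool first, then score one vector and activate."""
    pooled = last_token_pool(hidden_states)
    return act_fn(linear(weight, bias, pooled))


# Three tokens with hidden size 3, and a classifier with two labels.
hidden = [[0.5, -1.0, 2.0], [1.0, 0.0, -0.5], [0.25, 0.75, 1.5]]
weight = [[1.0, 0.0, -1.0], [0.5, 0.5, 0.5]]
bias = [0.0, 0.1]

old = classify_then_pool(hidden, weight, bias)
new = pool_then_classify(hidden, weight, bias)
```

Under these assumptions `old` and `new` hold the same two scores, but the second flow runs the classifier once per sequence instead of once per token, which is one plausible motivation for moving the head into the pooler.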
This looks hacky. I'm planning to require models to define `pooler` as a `BasePooler` instance in the next PR so we can directly inspect `model.pooler` to get this information.
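A minimal sketch of what that follow-up could look like, under stated assumptions: the comment names only `BasePooler` and `model.pooler`, so the class body, the `supported_tasks` method, `LastTokenClassifierPooler`, `ToyModel`, and `pooling_tasks` below are all hypothetical stand-ins, and "this information" is assumed here to be the set of pooling tasks a model supports.

```python
from abc import ABC, abstractmethod


class BasePooler(ABC):
    """Hypothetical stand-in for the base class named in the comment."""

    @abstractmethod
    def supported_tasks(self) -> set[str]:
        """Return the pooling tasks this pooler can serve (assumed API)."""


class LastTokenClassifierPooler(BasePooler):
    """Illustrative pooler that only supports classification."""

    def supported_tasks(self) -> set[str]:
        return {"classify"}


class ToyModel:
    """Illustrative model exposing its pooler as a plain attribute."""

    def __init__(self) -> None:
        self.pooler = LastTokenClassifierPooler()


def pooling_tasks(model: object) -> set[str]:
    """Inspect `model.pooler` directly instead of special-casing model types."""
    pooler = getattr(model, "pooler", None)
    if not isinstance(pooler, BasePooler):
        return set()
    return pooler.supported_tasks()


tasks = pooling_tasks(ToyModel())
```

The design point is that callers query one attribute with a known base type, so no per-model workaround is needed to discover pooling capabilities.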