Conversation

DarkLight1337
Member

@DarkLight1337 DarkLight1337 commented Jul 14, 2025

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing a test command.
  • The test results, such as pasting a before/after results comparison or e2e results.
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Purpose

Use ClassifierPooler in as_seq_cls_model and JambaForSequenceClassification. cc @noooop @maxdebayser

cc @yecohn @tomeras91 can you help check whether the output of Jamba-reward remains the same? There are no tests that check the correctness of this model.

Test Plan

Test Result

(Optional) Documentation Update

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
@DarkLight1337 DarkLight1337 requested a review from Isotr0py July 14, 2025 15:17

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which starts with a small, essential subset of CI tests to quickly catch errors. You can run additional CI tests on top of these by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Summary of Changes

Hello @DarkLight1337, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request undertakes a significant refactoring of the pooling and classification mechanisms within the model executor. The primary goal is to consolidate and standardize the implementation of pooling layers, especially for sequence classification tasks, by introducing and integrating a unified ClassifierPooler. This change aims to enhance code consistency, modularity, and maintainability across various models that perform classification.

Highlights

  • Consolidated Pooling Logic: The ClassifierPooler is now the central component for handling pooling, applying the classifier, and managing activation functions for sequence classification models. This unifies the approach across different model types.
  • Refactored Pooler Class: The base Pooler class has been refactored to extract its activation logic into a dedicated activation method, improving modularity and reusability.
  • Standardized as_seq_cls_model: Models utilizing the as_seq_cls_model decorator now instantiate and use the ClassifierPooler for their classification head, simplifying their internal pooler method and removing redundant logic.
  • Jamba Model Integration: The JambaForSequenceClassification model has been updated to leverage the new ClassifierPooler for its reward head, streamlining its implementation and introducing a custom weight loader to ensure float32 precision for its score weights.
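The float32 weight-loading idea in the last highlight can be sketched roughly as follows. This is a hedged illustration, not vLLM's actual implementation: fp32_weight_loader and the toy score head are invented names for demonstration.

```python
import torch
import torch.nn as nn


def fp32_weight_loader(param: nn.Parameter, loaded_weight: torch.Tensor) -> None:
    """Copy a checkpoint tensor into `param`, upcasting to float32 first.

    Keeping the classification/reward head in float32 preserves the extra
    precision stored in the checkpoint even when the rest of the model
    runs in a lower precision such as bfloat16.
    """
    assert param.dtype == torch.float32
    param.data.copy_(loaded_weight.to(torch.float32))


# Usage: a toy score head whose weights stay in float32 even though the
# checkpoint tensor arrives in bfloat16.
score = nn.Linear(8, 2, bias=False)
checkpoint_weight = torch.randn(2, 8, dtype=torch.bfloat16)
fp32_weight_loader(score.weight, checkpoint_weight)
```

The same pattern generalizes to any head whose weights should not be downcast to the model's runtime dtype.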
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e. a comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

| Feature | Command | Description |
| --- | --- | --- |
| Code Review | `/gemini review` | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | `/gemini summary` | Provides a summary of the current pull request in its current state. |
| Comment | `@gemini-code-assist` | Responds in comments when explicitly tagged, both in issue comments and review comments. |
| Help | `/gemini help` | Displays a list of available commands. |

Customization

To customize your Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.

You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request refactors the pooler implementations to use a consolidated ClassifierPooler. I've found a critical issue in both adapters.py and jamba.py where the new ClassifierPooler is initialized without a pooler instance, causing the pooling step to be skipped. I've provided suggestions to fix this. For jamba.py, this fix also changes the order of operations from pool(classify(h)) to the more standard classify(pool(h)), which will affect the output as you noted in the PR description.

Comment on lines 212 to 220
```python
def _init_pooler(self, vllm_config: "VllmConfig"):
    pooler_config = vllm_config.model_config.pooler_config
    assert pooler_config is not None

    return ClassifierPooler(
        vllm_config.model_config,
        self.score,
        act_fn=super()._init_pooler(vllm_config).head.activation,
    )
```
Contributor


critical

The ClassifierPooler is initialized without a pooler instance. This will cause the pooling step (extract_states) to be skipped inside ClassifierPooler.forward, as self.pooler will be None. This is a correctness issue, as the classifier will be applied to unpooled hidden states, which is likely not the intended behavior.

To fix this, you should create the inner_pooler first and pass it to the ClassifierPooler.

Suggested change

Before:

```python
def _init_pooler(self, vllm_config: "VllmConfig"):
    pooler_config = vllm_config.model_config.pooler_config
    assert pooler_config is not None
    return ClassifierPooler(
        vllm_config.model_config,
        self.score,
        act_fn=super()._init_pooler(vllm_config).head.activation,
    )
```

After:

```python
def _init_pooler(self, vllm_config: "VllmConfig"):
    pooler_config = vllm_config.model_config.pooler_config
    assert pooler_config is not None
    inner_pooler = super()._init_pooler(vllm_config)
    return ClassifierPooler(
        vllm_config.model_config,
        self.score,
        act_fn=inner_pooler.head.activation,
    )
```

Comment on lines 594 to 597
```python
self._pooler = ClassifierPooler(
    vllm_config.model_config,
    classifier=self.score,
    act_fn=inner_pooler.head.activation,
```
Contributor


critical

The ClassifierPooler is initialized without a pooler instance. This will cause the pooling step to be skipped, which is a change from the previous behavior and likely incorrect: the classifier will be applied to unpooled hidden states.

Additionally, the original implementation for Jamba was pool(classify(hidden_states)), while the standard approach (and what ClassifierPooler implements with the suggested fix) is classify(pool(hidden_states)). This change in logic will affect the model's output, which you've noted in the PR description. Please confirm whether this logic change is intended.

To fix the missing pooling step, you should pass the inner_pooler to the ClassifierPooler.

Suggested change

Before:

```python
self._pooler = ClassifierPooler(
    vllm_config.model_config,
    classifier=self.score,
    act_fn=inner_pooler.head.activation,
```

After:

```python
inner_pooler = Pooler.from_config_with_defaults(
    pooler_config,
    pooling_type=PoolingType.LAST,
    normalize=False,
    softmax=False,
)
self._pooler = ClassifierPooler(
    vllm_config.model_config,
    classifier=self.score,
    act_fn=inner_pooler.head.activation,
)
```

@DarkLight1337 DarkLight1337 marked this pull request as draft July 14, 2025 15:30
@DarkLight1337
Member Author

Let me fix some bugs locally first

@noooop
Contributor

noooop commented Jul 14, 2025

I'm going to implement this TODO next:

```python
# TODO: The reward weights themselves have float32 accuracy data, we
# would like to load them in fp32 to get that extra precision.
```

Related to #19925

Contributor

@maxdebayser maxdebayser left a comment


Nice refactoring!

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
@DarkLight1337 DarkLight1337 marked this pull request as ready for review July 15, 2025 03:33
@DarkLight1337
Member Author

Had to do another round of refactoring, PTAL again

```python
    return build_output(pooled_data)


class StepPooler(BasePooler):
```
Member Author


This doesn't really fit in with the other pooling types, so I made it its own class.

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Member

@Isotr0py Isotr0py left a comment


Overall LGTM!

@noooop
Contributor

noooop commented Jul 15, 2025

Nice refactoring! +1
This implementation can simplify a lot of code

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
@Isotr0py Isotr0py enabled auto-merge (squash) July 15, 2025 06:24
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Jul 15, 2025
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
@DarkLight1337 DarkLight1337 disabled auto-merge July 15, 2025 07:19
```python
        return None


def get_classification_activation_function(config: PretrainedConfig):
```
Member Author


These functions aren't used anywhere outside of pooler.py, so I moved them.

"""Check if the model uses step pooler."""
return is_pooling_model(model) and any(
type(module).__name__ == "StepPool" for module in model.modules())
type(module).__name__ == "StepPooler" for module in model.modules())
Member Author


This looks hacky. I'm planning to require models to define pooler as a BasePooler instance in the next PR, so we can directly inspect model.pooler to get this information.
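The difference between the name-matching check and the planned instance-based check can be sketched with toy stand-ins. Everything below is illustrative (the class and function names are not vLLM's actual API); it only demonstrates the design choice.

```python
# Toy stand-ins, not vLLM's real classes.
class BasePooler:
    pass


class StepPooler(BasePooler):
    pass


class ToyModel:
    def __init__(self):
        self.pooler = StepPooler()

    def modules(self):
        # Mimic torch.nn.Module.modules(): yield self, then submodules.
        yield self
        yield self.pooler


def uses_step_pooler_by_name(model) -> bool:
    # Current check: fragile string comparison against the class name,
    # which breaks on renames (as the StepPool -> StepPooler diff showed).
    return any(type(m).__name__ == "StepPooler" for m in model.modules())


def uses_step_pooler(model) -> bool:
    # Planned check: inspect model.pooler directly once it is guaranteed
    # to be a BasePooler instance.
    return isinstance(getattr(model, "pooler", None), StepPooler)


model = ToyModel()
print(uses_step_pooler_by_name(model), uses_step_pooler(model))
```

The instance-based check survives renames and subclassing, which is why requiring a `pooler` attribute makes the hack unnecessary.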

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
@noooop
Contributor

noooop commented Jul 16, 2025

set MTEB_RERANK_TOL = 0.01

I will switch to a more robust test as soon as possible

@noooop
Contributor

noooop commented Jul 16, 2025

We cannot rule out a potential numerical problem with this PR, but it would be very difficult to track down.

https://buildkite.com/vllm/ci/builds/24088/steps/canvas?sid=01981164-83cc-49e3-bc15-5b73f41a98ba
[2025-07-16T07:33:04Z] VLLM: torch.bfloat16 0.27263
[2025-07-16T07:33:04Z] SentenceTransformers: torch.float32 0.273
[2025-07-16T07:33:04Z] Difference: 0.00037000000000003697
[2025-07-16T07:33:05Z] PASSED

https://buildkite.com/vllm/ci/builds/23961/steps/canvas?sid=01980c3e-3f93-48c6-9c00-f25ec6b40f6a
[2025-07-15T06:42:55Z] VLLM: torch.bfloat16 0.27336
[2025-07-15T06:42:55Z] SentenceTransformers: torch.float32 0.273
[2025-07-15T06:42:55Z] Difference: -0.00035999999999997145
[2025-07-15T06:42:56Z] PASSED

this pr
[2025-07-16T06:24:12Z] VLLM: torch.bfloat16 0.26798
[2025-07-16T06:24:12Z] SentenceTransformers: torch.float32 0.273
[2025-07-16T06:24:12Z] Difference: 0.0050200000000000244
[2025-07-16T06:24:13Z] FAILED

@DarkLight1337
Member Author

Yeah, that is what I mean; I'm trying to find out why the score decreased in this PR.

@noooop
Contributor

noooop commented Jul 16, 2025

models/language/pooling/test_mxbai_rerank.py
- A:      -0.00035999999999997145
- B:       0.00037000000000003697
- C:       0.00019000000000002348
- this pr: 0.0050200000000000244


tomaarsen/Qwen3-Reranker-0.6B-seq-cls
- A:       -1.0000000000010001e-05
- B:       -0.0005000000000000004
- C:       -0.00028000000000000247
- this pr: -0.000490000000000046


models/language/pooling/test_qwen3_reranker.py
- A:       -0.00036000000000002697
- B:       -0.00017000000000000348
- C:       -0.0009700000000000264
- this pr: -0.000490000000000046


BAAI/bge-reranker-v2-gemma
- A:       0.0013299999999999979
- B:       0.0013299999999999979
- C:       0.0013299999999999979
- this pr: 0.0006399999999999739


cross-encoder/ms-marco-TinyBERT-L-2-v2
- A: -3.999999999998449e-05
- B: -3.999999999998449e-05
- C: -3.999999999998449e-05
- this pr: -3.999999999998449e-05

BAAI/bge-reranker-base
- A: 2.999999999997449e-05
- B: 2.999999999997449e-05
- C: 2.999999999997449e-05
- this pr: 2.999999999997449e-05

Maybe there is indeed a potential numerical stability problem, though it does not affect BERT-like models.

BAAI/bge-reranker-v2-gemma is relatively stable and can be used for debugging.

I think the difference is relatively small; we can first set MTEB_RERANK_TOL = 0.01 and then track down the cause gradually.

@noooop
Contributor

noooop commented Jul 16, 2025

There is a detail that may need attention: because V1 does not support Gemma, V0 is used.

```python
@pytest.mark.parametrize("model_info", RERANK_MODELS)
def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo,
                            monkeypatch) -> None:
    monkeypatch.setenv("VLLM_USE_V1", "0")
```

So are the numerics of V1 less stable?

@DarkLight1337
Member Author

I think I found the issue. On the main branch, the data passed to the activation function is in float32, but in this PR it is in the model's dtype. Let's see if using float32 fixes the problem.
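As a minimal, hypothetical illustration of this dtype gap (not the PR's actual code), applying a sigmoid activation to a pooled logit in bfloat16 versus float32 already yields visibly different scores:

```python
import torch

# A single pooled logit stored in the model's dtype (bfloat16).
logits = torch.tensor([2.337], dtype=torch.bfloat16)

# Activation applied in the model's dtype: both the input and the
# sigmoid output are rounded to bfloat16's 8 mantissa bits.
score_bf16 = torch.sigmoid(logits).item()

# Activation applied after upcasting to float32 (the main-branch behavior).
score_fp32 = torch.sigmoid(logits.to(torch.float32)).item()

print(score_bf16, score_fp32)
```

The two scores differ in the third decimal place, which is comparable in magnitude to the score drift seen in the MTEB rerank numbers above.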

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
@DarkLight1337 DarkLight1337 enabled auto-merge (squash) July 16, 2025 11:05
@noooop
Contributor

noooop commented Jul 16, 2025

buildkite/ci/pr/language-models-test-extended-pooling passed

BAAI/bge-reranker-v2-gemma
Difference: 0.0013299999999999979

The reranker test isn't completely useless

@DarkLight1337 DarkLight1337 merged commit 1c3198b into vllm-project:main Jul 16, 2025
64 checks passed
@DarkLight1337 DarkLight1337 deleted the consolidate-poolers branch July 16, 2025 13:39
nadathurv pushed a commit to nadathurv/vllm that referenced this pull request Jul 16, 2025
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
wangxiyuan pushed a commit to vllm-project/vllm-ascend that referenced this pull request Jul 17, 2025
### What this PR does / why we need it?
- Fix broken commit by
[#20927](vllm-project/vllm#20927)
- Fix broken commit by
[#20466](vllm-project/vllm#20466)
- TODO: more fully adapt to the upstream reconstruction, let's first
make CI happy

- vLLM version: v0.9.2
- vLLM main:
vllm-project/vllm@11dfdf2

---------

Signed-off-by: wangli <wangli858794774@gmail.com>
hj-mistral pushed a commit to hj-mistral/vllm that referenced this pull request Jul 19, 2025
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Himanshu Jaju <hj@mistral.ai>
LyrisZhong pushed a commit to LyrisZhong/vllm that referenced this pull request Jul 23, 2025
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
avigny pushed a commit to avigny/vllm that referenced this pull request Jul 31, 2025
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: avigny <47987522+avigny@users.noreply.github.com>
x22x22 pushed a commit to x22x22/vllm that referenced this pull request Aug 5, 2025
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: x22x22 <wadeking@qq.com>
Pradyun92 pushed a commit to Pradyun92/vllm that referenced this pull request Aug 6, 2025
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
npanpaliya pushed a commit to odh-on-pz/vllm-upstream that referenced this pull request Aug 6, 2025
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
jinzhen-lin pushed a commit to jinzhen-lin/vllm that referenced this pull request Aug 9, 2025
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
paulpak58 pushed a commit to paulpak58/vllm that referenced this pull request Aug 13, 2025
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Paul Pak <paulpak58@gmail.com>
taneem-ibrahim pushed a commit to taneem-ibrahim/vllm that referenced this pull request Aug 14, 2025
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
diegocastanibm pushed a commit to diegocastanibm/vllm that referenced this pull request Aug 15, 2025
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Diego-Castan <diego.castan@ibm.com>
epwalsh pushed a commit to epwalsh/vllm that referenced this pull request Aug 27, 2025
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
googlercolin pushed a commit to googlercolin/vllm that referenced this pull request Aug 29, 2025
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Labels
ready ONLY add when PR is ready to merge/full CI is needed
4 participants