Skip to content

[Model][4/N] Automatic conversion of CrossEncoding model #19675

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 12 commits into
base: main
Choose a base branch
from

Conversation

noooop
Copy link
Contributor

@noooop noooop commented Jun 16, 2025

TL;DR

  • vllm.model_executor.models.adapters.as_seq_cls_model
    • Theoretically, this PR could allow all *ForCausalLM models to automatically have *ForSequenceClassification implementation, and automatically support classify (classification) task and score (rerank) task.
    • simplify the code of the previous models
      • Qwen3ForSequenceClassification
    • New Model
      • GemmaForSequenceClassification
      • It should support any ForSequenceClassification models, but no suitable hf repository was found for testing.
  • converting2seq_cls_models.py
    • Offline convert ForCausalLM into ForSequenceClassification model.
    • from_2_way_softmax
      • Qwen/Qwen3-Reranker-0.6B
      • Qwen/Qwen3-Reranker-4B
      • Qwen/Qwen3-Reranker-8B
    • no_post_processing
      • BAAI/bge-reranker-v2-gemma
  • BAAI/bge-reranker-v2-gemma needs to directly concatenate the query and document without using pad_token to get exactly the same result as the official one.

Hope that after merging this pr, vllm can support more llms using the relevance generation method as classifiers and rerankers.

Usage

converting2seq_cls_models.py:

  • for Qwen3-Reranker
python convert_model_to_seq_cls.py --model_name Qwen/Qwen3-Reranker-0.6B --classifier_from_tokens '["no", "yes"]' --method from_2_way_softmax --path ./Qwen3-Reranker-0.6B-seq-cls

PTAL #19260

  • for BAAI/bge-reranker-v2-gemma
python convert_model_to_seq_cls.py --model_name BAAI/bge-reranker-v2-gemma --classifier_from_tokens '["Yes"]' --method no_post_processing --path ./bge-reranker-v2-gemma-seq-cls

Caution

"Yes" and "yes" are two different tokens

# v1 temporarily will report an error
VLLM_USE_V1=0 vllm serve ./bge-reranker-v2-gemma-seq-cls --task score --served-model-name BAAI/bge-reranker-v2-gemma

requests demo + formating query & document:

import requests

url = "http://127.0.0.1:8000/score"
MODEL_NAME = "BAAI/bge-reranker-v2-gemma"

# Please use the query_template and document_template to format the query and
# document for better reranker results.

prompt = "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'."
query_template = "A: {query}\n"
document_template = "B: {doc}\n{prompt}"

queries = [
    "What is the capital of China?",
    "Explain gravity",
]

documents = [
    "The capital of China is Beijing.",
    "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun.",
]

queries= [query_template.format(query=query) for query in queries ]
documents= [
    document_template.format(doc=doc, prompt=prompt)
    for doc in documents 
]

response = requests.post(url,
                         json={
                             "model": MODEL_NAME,
                             "text_1": queries,
                             "text_2": documents,
                             "truncate_prompt_tokens": -1,
                         }).json()

print(response)

expected output

{'id': 'score-30d577279e674c98ad1b0ee12f978b67', 'object': 'list', 'created': 1750361179, 'model': 'BAAI/bge-reranker-v2-gemma', 'data': [{'index': 0, 'object': 'score', 'score': 0.9998812675476074}, {'index': 1, 'object': 'score', 'score': 0.9997507929801941}], 'usage': {'prompt_tokens': 126, 'total_tokens': 126, 'completion_tokens': 0, 'prompt_tokens_details': None}}

If someone wants to implement an offline conversion from ForCausalLM to ForSequenceClassification support new methods or new models, please refer to

https://github.com/noooop/snippet/tree/main/converting2SequenceClassification

(I don't know where to place this code in vllm.)

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Purpose

Follow-up #11469, Further improve #10674

  • Automatic conversion of CrossEncoding models

Test Plan

Test Result

(Optional) Documentation Update

Known issues

  1. needs refactoring
  • in get_model_architecture
    model_cls, arch = ModelRegistry.resolve_model_cls(architectures)
    if model_config.task == "embed":
        model_cls = as_embedding_model(model_cls)
    elif model_config.task in ["classify", "score"]:
        model_cls = as_seq_cls_model(model_cls)
    elif model_config.task == "reward":
        model_cls = as_reward_model(model_cls)
  • _ModelRegistry.is_cross_encoder_model not considered as_seq_cls_model
  • This pr actually allows all ForCausalLM to support corresponding ForSequenceClassification. Do we really need to list out all Auto-converted architectures?

# [Auto-converted (see adapters.py)]
"Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForCausalLM"),

  1. default_softmax=True in conflict with what in documentation (default_softmax=False when the task is score).

In vllm, the score task only supports num_labels == 1, while models with num_labels == 1 in sentence-transformers use Sigmoid by default.

https://github.com/UKPLab/sentence-transformers/blob/910ed144dfc0a08f31517b0d01580302015fa408/sentence_transformers/cross_encoder/CrossEncoder.py#L485-L487

        if self.config.num_labels == 1:
            return nn.Sigmoid()

Perhaps we should update the documentation to set default_softmax=True when the task is score, consistent with the implementation in sentence-transformers. And we should pin the sentence-transformers version to >= 4.1.0.

  1. change verify_and_update_config into a class method in the future and call it when initializing model_config. fix_by [Model][1/N] Automatic conversion of CrossEncoding model #20012

  2. Template aware prompt truncation to avoid cutting off important instructions.

  3. Alibaba-NLP/gte-Qwen2-1.5B-instruct & Alibaba-NLP/gte-modernbert-base

NotImplementedError: Encoder self-attention and encoder/decoder cross-attention are not implemented for FlashAttentionImpl

Fix #19673
Fix #20051

Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Summary of Changes

Hello @noooop, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request refactors the automatic model conversion mechanism to support both classification and scoring tasks under a unified as_seq_cls_model adapter. It updates the model loading utility to recognize the score task and applies the generic adapter. The adapter itself is refactored to handle pooling and scoring more flexibly, and the Qwen3 sequence classification model is updated to utilize this new generic adapter.

Highlights

  • Score Task Support: The automatic model conversion logic in the model loader (vllm/model_executor/model_loader/utils.py) is updated to use the as_seq_cls_model adapter for models specified with the score task, in addition to the existing classify task.
  • Adapter Refactoring: The as_seq_cls_model adapter in vllm/model_executor/models/adapters.py is refactored. The pooling and scoring logic is moved into a dedicated pooler method, allowing for more flexible handling of different pooling types within the adapter. It also adds a check to squeeze the output dimension for the score task.
  • Qwen3 Reranker Integration: The specific Qwen3ForSequenceClassification implementation is updated to inherit from the new generic as_seq_cls_model adapter, simplifying its structure and leveraging the shared adapter logic. Specific Qwen3 reranker configuration verification is moved into a new config_verify method.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in issue comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configureGemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@mergify mergify bot added the documentation Improvements or additions to documentation label Jun 16, 2025
Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request focuses on enabling automatic conversion of score models by adapting the sequence classification model functionality. Key changes include:

  • Renaming as_classification_model to as_seq_cls_model and updating its functionality to support both "classify" and "score" tasks. This is consistently applied across documentation, tests, and model loading utilities.
  • Refactoring Qwen3ForSequenceClassification to leverage the new as_seq_cls_model adapter. This promotes code reuse and centralizes the classification/scoring logic.
  • Introducing a config_verify method in the adapter pattern, allowing model-specific configurations, which is well-utilized by Qwen3ForSequenceClassification for its reranker variant.
  • Ensuring that for "score" tasks, the model expects num_labels == 1 and the output is appropriately processed (squeezed).

The changes appear robust and improve the model adaptation framework. One area for potential clarification is the behavior of PoolingType.ALL within the as_seq_cls_model adapter, as noted in the specific comment.

Please also consider filling out the checklist in the PR description (Purpose, Test Plan, Test Result) for completeness.

@noooop noooop changed the title [Model] Automatic conversion of score models [Model] Automatic conversion of score (CrossEncoding) models Jun 16, 2025
@noooop noooop force-pushed the as_score_model branch 4 times, most recently from 6dc55ba to 00d377b Compare June 18, 2025 08:33
@noooop noooop force-pushed the as_score_model branch 2 times, most recently from aa22cad to 71b1df4 Compare June 18, 2025 10:49
@mergify mergify bot added the qwen Related to Qwen models label Jun 18, 2025
Copy link

mergify bot commented Jun 19, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @noooop.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jun 19, 2025
@noooop noooop closed this Jun 19, 2025
@noooop noooop reopened this Jun 19, 2025
Copy link

mergify bot commented Jun 19, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @noooop.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

Signed-off-by: wang.yuqi <noooop@126.com>
Signed-off-by: wang.yuqi <noooop@126.com>
@noooop
Copy link
Contributor Author

noooop commented Jun 19, 2025

@DarkLight1337

_ModelRegistry.is_cross_encoder_model not considered as_seq_cls_model

It seems I need to spend some more time fixing it.

@noooop noooop marked this pull request as draft June 19, 2025 11:53
@mergify mergify bot added the frontend label Jun 19, 2025
Signed-off-by: wang.yuqi <noooop@126.com>
@noooop noooop marked this pull request as ready for review June 20, 2025 02:03
@noooop
Copy link
Contributor Author

noooop commented Jun 20, 2025

@DarkLight1337

I find it difficult to fully fix is_cross_encoder, and I am not familiar with this part of code.

So I will use a temporary solution to fix it and document in Known issues.

    def is_cross_encoder(self) -> bool:
        # Temporary solution, See #19675
        return (self.registry.is_cross_encoder_model(self.architectures) or
                "forsequenceclassification" in self.architectures[0].lower())

how did this issues occur

  • For the converted ./bge-reranker-v2-gemma-seq-cls model, its architecture is GemmaForSequenceClassification.

In vllm/model_executor/models/registry.py, it is routed to GemmaForCausalLM.

"GemmaForSequenceClassification": ("gemma", "GemmaForCausalLM"),

But _ModelRegistry.is_cross_encoder_model not considered as_seq_cls_model

so it derived that GemmaForSequenceClassification is_cross_encoder_model == False, which leads to a wrong calculation method

  • For the converted ./Qwen3-Reranker-0.6B-seq-cls model, its architecture is Qwen3ForSequenceClassification.

In vllm/model_executor/models/registry.py, it is routed to Qwen3ForSequenceClassification.

"Qwen3ForSequenceClassification": ("qwen3", "Qwen3ForSequenceClassification"),

so it derived that GemmaForSequenceClassification is_cross_encoder_model == True

  • Therefore, for any ForCausalLM, need add the following code in the model definition file.
XXXForSequenceClassification = as_seq_cls_model(XXXForCausalLM)

Add the following code in vllm/model_executor/models/registry.py

"XXXForSequenceClassification": ("xxx", "XXXForSequenceClassification"),

This is clearly too tedious.

noooop added 2 commits June 20, 2025 10:30
Signed-off-by: wang.yuqi <noooop@126.com>
Signed-off-by: wang.yuqi <noooop@126.com>
@@ -176,7 +178,8 @@ def as_classification_model(cls: _T) -> _T:
default_softmax=True,
)

class ModelForClassification(ModelForPooling):
class ModelForSequenceClassification(ModelForPooling,
SupportsCrossEncoding):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why can't is_cross_encoder_model return true for this model?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because the code block was not run in ModelRegistry.is_cross_encoder_model

model_cls, arch = ModelRegistry.resolve_model_cls(architectures)
if model_config.task == "embed":
    model_cls = as_embedding_model(model_cls)
elif model_config.task in ["classify", "score"]:
    model_cls = as_seq_cls_model(model_cls)
elif model_config.task == "reward":
    model_cls = as_reward_model(model_cls)

_ModelRegistry.is_cross_encoder_model not considered as_seq_cls_model

Copy link
Contributor Author

@noooop noooop Jun 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can fix it in two ways, I'm not sure which one you prefer.

  1. let _ModelRegistry.is_cross_encoder_model consider as_seq_cls_model, which requires passing a task parameter and modifying many interfaces.there might also be duplicate code.
  2. Automatically add for each ForCausalLM: XXXForSequenceClassification = as_seq_cls_model(XXXForCausalLM) and add "XXXForSequenceClassification": ("xxx", "XXXForSequenceClassification"). Although I haven't figured out how to do it yet.

Copy link
Member

@DarkLight1337 DarkLight1337 Jun 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I don't quite get it. In this code the model is converted into ModelForSequenceClassification via as_seq_cls_model when --task score , which allows the model to be detected as is_cross_encoder_model, right?

Copy link
Contributor Author

@noooop noooop Jun 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

    @property
    def is_cross_encoder(self) -> bool:
        return self.registry.is_cross_encoder_model(self.architectures)  <- here
    def is_cross_encoder_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures)  <- here
        return model_cls.supports_cross_encoding
@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.
    """
    module_name: str
    class_name: str

    # Performed in another process to avoid initializing CUDA
    def inspect_model_cls(self) -> _ModelInfo:
        return _run_in_subprocess(
            lambda: _ModelInfo.from_model_cls(self.load_model_cls()))  <- here 

    def load_model_cls(self) -> type[nn.Module]:
        mod = importlib.import_module(self.module_name)
        return getattr(mod, self.class_name)
    @staticmethod
    def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
        return _ModelInfo(
            architecture=model.__name__,
            is_text_generation_model=is_text_generation_model(model),
            is_pooling_model=True,  # Can convert any model into a pooling model
            supports_cross_encoding=supports_cross_encoding(model),   <- here # this model expected as_seq_cls_model(GemmaForCausalLM), but actually is GemmaForCausalLM, 
            supports_multimodal=supports_multimodal(model),
            supports_pp=supports_pp(model),
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
            is_hybrid=is_hybrid(model),
            supports_transcription=supports_transcription(model),
            supports_v0_only=supports_v0_only(model),
            has_noops=has_noops(model),
        )

not using: model_cls = as_seq_cls_model(model_cls)

_ModelRegistry.is_cross_encoder_model not considered as_seq_cls_model

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When computing _resolve_task, it calls "if self.registry.is_cross_encoder_model(architectures)", so a circular reference might be formed.

    def _get_preferred_task(
        self,
        architectures: list[str],
        supported_tasks: set[_ResolvedTask],
    ) -> Optional[_ResolvedTask]:
        model_id = self.model
        if get_pooling_config(model_id, self.revision):
            return "embed"
        if self.registry.is_cross_encoder_model(architectures):  <- here
            return "score"
        if self.registry.is_transcription_model(architectures):
            return "transcription"

        suffix_to_preferred_task: list[tuple[str, _ResolvedTask]] = [
            # Other models follow this pattern
            ("ForCausalLM", "generate"),
            ("ForConditionalGeneration", "generate"),
            ("ForSequenceClassification", "classify"),
            ("ChatModel", "generate"),
            ("LMHeadModel", "generate"),
            ("EmbeddingModel", "embed"),
            ("RewardModel", "reward"),
        ]
        _, arch = self.registry.inspect_model_cls(architectures)

        for suffix, pref_task in suffix_to_preferred_task:
            if arch.endswith(suffix) and pref_task in supported_tasks:
                return pref_task

        return None

so _resolve_task and _get_preferred_task, as well as inspect_model_cls, need to be refactored

┑( ̄Д  ̄)┍

Copy link
Contributor Author

@noooop noooop Jun 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@DarkLight1337

Can I solve this problem in another PR as it is very complex and exceeds the scope of this PR?


I think redirecting *ForSequenceClassification to *ForCausalLM makes things complicated.

e.g.

"GemmaForSequenceClassification": ("gemma", "GemmaForCausalLM"),

At this time, *ForCausalLM might be used for generate, classify, embed, score. Among them, embed and score are difficult to distinguish,need to pass the task parameter to distinguish. If parsing fails, incorrect calculation methods will be used and wrong results will be obtained.

should use the code below

"GemmaForSequenceClassification": ("gemma", "GemmaForSequenceClassification"),

Then

*ForSequenceClassification prefers the classify & score task.

Even I think we should not distinguish between the classify & score tasks.

If there is a *ForSequenceClassification model, we should allow users to use both the classify API and the score API.

Their results should be the same, with the difference being whether users concatenate queries and documents or vllm does the concatenation.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

from vllm import LLM

model = LLM(model="BAAI/bge-reranker-base", task="score")

text_1 = "ping"
text_2 = "pong"

outputs = model.score(text_1, text_2)


print(outputs)


# [ScoringRequestOutput(request_id='0', outputs=ScoringOutput(score=0.77197265625), prompt_token_ids=[0, 33429, 2, 2, 114007, 2], finished=True)]


from vllm import LLM

model = LLM(model="BAAI/bge-reranker-base", task="classify")

text_1 = "ping"
text_2 = "pong"

# after changing the output dimensions slightly
outputs = model.classify([f'{text_1}</s></s>{text_2}'])

print(outputs)

# [ClassificationRequestOutput(request_id='0', outputs=ClassificationOutput(num_classes=1), prompt_token_ids=[0, 33429, 2, 2, 114007, 2], finished=True)]

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it help if we have a separate model class for --task score? Then if the model is *ForSequenceClassification, we first default to --task classify. So the user should explicitly set --task score in order to use it as a cross encoder.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@DarkLight1337

by adding

GemmaForSequenceClassification = as_seq_cls_model(GemmaForCausalLM)
Qwen2ForSequenceClassification = as_seq_cls_model(Qwen2ForCausalLM)

"GemmaForSequenceClassification": ("gemma", "GemmaForSequenceClassification"),
"Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForSequenceClassification"),

Now there are no test failures and the code is not too dirty, let's merge this pr first.

Let us discuss and solve the issue completely in the next pr.

Signed-off-by: wang.yuqi <noooop@126.com>
noooop added 2 commits June 21, 2025 18:21
Signed-off-by: wang.yuqi <noooop@126.com>
Signed-off-by: wang.yuqi <noooop@126.com>
Copy link
Member

@DarkLight1337 DarkLight1337 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for getting this to work!

@DarkLight1337 DarkLight1337 enabled auto-merge (squash) June 23, 2025 03:19
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Jun 23, 2025
@noooop
Copy link
Contributor Author

noooop commented Jun 23, 2025

@DarkLight1337

Now I have to deal with a big issue that will be handled in the next pr.

now all ForSequenceClassification models is_cross_encoder_model == True

def _get_preferred_task(
        .....
        if self.registry.is_cross_encoder_model(architectures):
            return "score"

the test below will fail:

 ("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify"),

def test_auto_task(model_id, expected_runner_type, expected_task):
   ......

AssertionError: assert 'score' == 'classify'


because the impact is significant

I will first assume that this pr has been merged and prepare the next pr.

Once the solution of the next pr is approved, then merge this pr.

@DarkLight1337 DarkLight1337 disabled auto-merge June 23, 2025 05:40
@noooop noooop closed this Jun 24, 2025
@noooop noooop reopened this Jun 25, 2025
@noooop noooop changed the title [Model] Automatic conversion of score (CrossEncoding) models [Model][4/N] Automatic conversion of CrossEncoding model Jun 27, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
documentation Improvements or additions to documentation frontend qwen Related to Qwen models ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[New Model]: mxbai-rerank-large-v2 [New Model]: Support BAAI/bge-reranker-v2-gemma model
2 participants