-
-
Notifications
You must be signed in to change notification settings - Fork 8.4k
[Model] Add classification Task with Qwen2ForSequenceClassification #9704
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
15 commits
Select commit
Hold shift + click to select a range
f31d0f2
classification compatible with debugginglogs.
jason9693 1bcdf25
fixed prefill error
jason9693 faff871
add test code
jason9693 568c2f9
remove unnecessary print and codes
jason9693 20faf46
remove unnecessary print, modifiied pooling logic.
kakao-kevin-us 4afa7e1
modified auto_cls logic, and lint check
jason9693 8546773
remve unnecessary print
jason9693 bde37c2
make docstring accurate
jason9693 658176f
add supported models
jason9693 f2ee1e2
modified docs
jason9693 18cf269
add AutoWeightsLoader loading
jason9693 65d3f50
move test code under embedding
jason9693 3374de6
remove unnecessary code and update info
jason9693 cc2a9ad
modified for linting
jason9693 81bad15
revert softmax inside the pooledr
jason9693 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
"""Compare the outputs of HF and vLLM when using greedy sampling. | ||
|
||
This test only tests small models. Big models such as 7B should be tested from | ||
test_big_models.py because it could use a larger instance to run tests. | ||
|
||
Run `pytest tests/models/test_cls_models.py`. | ||
""" | ||
import pytest | ||
import torch | ||
from transformers import AutoModelForSequenceClassification | ||
|
||
CLASSIFICATION_MODELS = ["jason9693/Qwen2.5-1.5B-apeach"] | ||
|
||
|
||
@pytest.mark.parametrize("model", CLASSIFICATION_MODELS) | ||
@pytest.mark.parametrize("dtype", ["float"]) | ||
def test_classification_models( | ||
hf_runner, | ||
vllm_runner, | ||
example_prompts, | ||
model: str, | ||
dtype: str, | ||
) -> None: | ||
with hf_runner(model, | ||
dtype=dtype, | ||
auto_cls=AutoModelForSequenceClassification) as hf_model: | ||
hf_outputs = hf_model.classify(example_prompts) | ||
|
||
with vllm_runner(model, dtype=dtype) as vllm_model: | ||
vllm_outputs = vllm_model.classify(example_prompts) | ||
|
||
print(hf_outputs, vllm_outputs) | ||
|
||
# check logits difference | ||
for hf_output, vllm_output in zip(hf_outputs, vllm_outputs): | ||
hf_output = torch.tensor(hf_output) | ||
vllm_output = torch.tensor(vllm_output) | ||
|
||
assert torch.allclose(hf_output, vllm_output, 1e-3) | ||
|
||
|
||
@pytest.mark.parametrize("model", CLASSIFICATION_MODELS) | ||
@pytest.mark.parametrize("dtype", ["float"]) | ||
def test_classification_model_print( | ||
vllm_runner, | ||
model: str, | ||
dtype: str, | ||
) -> None: | ||
with vllm_runner(model, dtype=dtype) as vllm_model: | ||
# This test is for verifying whether the model's extra_repr | ||
# can be printed correctly. | ||
print(vllm_model.model.llm_engine.model_executor.driver_worker. | ||
model_runner.model) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
# coding=utf-8 | ||
# Adapted from | ||
# https://huggingface.co/Qwen/Qwen2.5-Math-RM-72B/blob/main/modeling_qwen2_rm.py | ||
# Copyright 2024 Kakao Corp. (Kanana-X Team) | ||
# Copyright 2024 The Qwen team. | ||
# Copyright 2023 The vLLM team. | ||
"""Inference-only Qwen2-Classification model compatible with HF weights.""" | ||
from typing import Iterable, List, Optional, Tuple | ||
|
||
import torch | ||
from torch import nn | ||
from transformers import Qwen2Config | ||
|
||
from vllm.attention import AttentionMetadata | ||
from vllm.config import CacheConfig, LoRAConfig | ||
from vllm.model_executor.layers.linear import RowParallelLinear | ||
from vllm.model_executor.layers.pooler import Pooler, PoolingType | ||
from vllm.model_executor.layers.quantization.base_config import ( | ||
QuantizationConfig) | ||
from vllm.model_executor.models.qwen2 import Qwen2Model | ||
from vllm.model_executor.pooling_metadata import PoolingMetadata | ||
from vllm.sequence import IntermediateTensors, PoolerOutput | ||
|
||
from .utils import AutoWeightsLoader | ||
|
||
|
||
class Qwen2ForSequenceClassification(nn.Module): | ||
packed_modules_mapping = { | ||
"qkv_proj": [ | ||
"q_proj", | ||
"k_proj", | ||
"v_proj", | ||
], | ||
"gate_up_proj": [ | ||
"gate_proj", | ||
"up_proj", | ||
], | ||
} | ||
|
||
# LoRA specific attributes | ||
supported_lora_modules = [ | ||
"qkv_proj", | ||
"o_proj", | ||
"gate_up_proj", | ||
"down_proj", | ||
] | ||
embedding_modules = {} | ||
embedding_padding_modules = [] | ||
|
||
def __init__( | ||
self, | ||
config: Qwen2Config, | ||
cache_config: Optional[CacheConfig] = None, | ||
quant_config: Optional[QuantizationConfig] = None, | ||
lora_config: Optional[LoRAConfig] = None, | ||
) -> None: | ||
# TODO (@robertgshaw2): see if this can be moved out | ||
if (cache_config.sliding_window is not None | ||
and hasattr(config, "max_window_layers")): | ||
raise ValueError("Sliding window for some but all layers is not " | ||
"supported. This model uses sliding window " | ||
"but `max_window_layers` = %s is less than " | ||
"`num_hidden_layers` = %s. Please open an issue " | ||
"to discuss this feature." % ( | ||
config.max_window_layers, | ||
config.num_hidden_layers, | ||
)) | ||
|
||
super().__init__() | ||
|
||
self.config = config | ||
self.lora_config = lora_config | ||
|
||
self.quant_config = quant_config | ||
self.model = Qwen2Model(config, cache_config, quant_config) | ||
|
||
self.score = RowParallelLinear(config.hidden_size, | ||
config.num_labels, | ||
quant_config=quant_config) | ||
self._pooler = Pooler(pooling_type=PoolingType.LAST, | ||
normalize=False, | ||
softmax=True) | ||
|
||
def forward( | ||
self, | ||
input_ids: torch.Tensor, | ||
positions: torch.Tensor, | ||
kv_caches: List[torch.Tensor], | ||
attn_metadata: AttentionMetadata, | ||
intermediate_tensors: Optional[IntermediateTensors] = None, | ||
) -> torch.Tensor: | ||
hidden_states = self.model(input_ids, positions, kv_caches, | ||
attn_metadata, intermediate_tensors) | ||
logits, _ = self.score(hidden_states) | ||
return logits | ||
|
||
def pooler( | ||
self, | ||
hidden_states: torch.Tensor, | ||
pooling_metadata: PoolingMetadata, | ||
) -> Optional[PoolerOutput]: | ||
return self._pooler(hidden_states, pooling_metadata) | ||
|
||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): | ||
loader = AutoWeightsLoader(self, | ||
ignore_unexpected_prefixes=["lm_head."]) | ||
loader.load_weights(weights) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.