
Commit 19332c0

[Model] Systematic support for fp32 head, pooling models part (#23810)
Signed-off-by: wang.yuqi <noooop@126.com>
1 parent a55cf41 commit 19332c0

File tree: 14 files changed, +166 −61 lines changed

tests/models/language/pooling/mteb_utils.py
Lines changed: 31 additions & 6 deletions

@@ -9,6 +9,7 @@
 import numpy as np
 import pytest
 import requests
+import torch

 from tests.models.utils import (EmbedModelInfo, RerankModelInfo,
                                 check_embeddings_close)
@@ -165,16 +166,19 @@ def mteb_test_embed_models(hf_runner,
                            vllm_extra_kwargs=None,
                            hf_model_callback=None,
                            atol=MTEB_EMBED_TOL):
+    # A model family has many models with the same architecture,
+    # and we don't need to test each one.
     if not model_info.enable_test:
-        # A model family has many models with the same architecture,
-        # and we don't need to test each one.
         pytest.skip("Skipping test.")

-    example_prompts = ["The chef prepared a delicious meal."]
+    # Test embed_dims, isnan and whether to use normalize
+    example_prompts = ["The chef prepared a delicious meal." * 1000]

+    # Allow vllm to test using the given dtype, such as float32
     vllm_extra_kwargs = vllm_extra_kwargs or {}
     vllm_extra_kwargs["dtype"] = model_info.dtype

+    # Allow vllm to test using hf_overrides
     if model_info.hf_overrides is not None:
         vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides

@@ -186,21 +190,32 @@ def mteb_test_embed_models(hf_runner,

         model_config = vllm_model.llm.llm_engine.model_config

+        # Confirm whether vllm is using the correct architecture
         if model_info.architecture:
             assert model_info.architecture in model_config.architectures
+
+        # Confirm whether vllm uses the correct default_pooling_type, which
+        # relates to whether chunked prefill and prefix caching are enabled
         assert (model_config._model_info.default_pooling_type ==
                 model_info.default_pooling_type)

         vllm_main_score = run_mteb_embed_task(VllmMtebEncoder(vllm_model),
                                               MTEB_EMBED_TASKS)
         vllm_dtype = vllm_model.llm.llm_engine.model_config.dtype
-        vllm_outputs = vllm_model.embed(example_prompts)

+        # Test embed_dims, isnan and whether to use normalize
+        vllm_outputs = vllm_model.embed(example_prompts,
+                                        truncate_prompt_tokens=-1)
+        assert not torch.any(torch.isnan(torch.tensor(vllm_outputs)))
+
+    # Accelerate mteb test by setting
+    # SentenceTransformers mteb score to a constant
     if model_info.mteb_score is None:
         with hf_runner(model_info.name,
                        is_sentence_transformer=True,
                        dtype="float32") as hf_model:

+            # e.g. setting default parameters for the encode method of hf_runner
             if hf_model_callback is not None:
                 hf_model_callback(hf_model)

@@ -299,14 +314,16 @@ def mteb_test_rerank_models(hf_runner,
                             hf_model_callback=None,
                             vllm_mteb_encoder=VllmMtebEncoder,
                             atol=MTEB_RERANK_TOL):
+    # A model family has many models with the same architecture,
+    # and we don't need to test each one.
     if not model_info.enable_test:
-        # A model family has many models with the same architecture,
-        # and we don't need to test each one.
         pytest.skip("Skipping test.")

+    # Allow vllm to test using the given dtype, such as float32
     vllm_extra_kwargs = vllm_extra_kwargs or {}
     vllm_extra_kwargs["dtype"] = model_info.dtype

+    # Allow vllm to test using hf_overrides
     if model_info.hf_overrides is not None:
         vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides

@@ -319,9 +336,15 @@ def mteb_test_rerank_models(hf_runner,

         model_config = vllm_model.llm.llm_engine.model_config

+        # Confirm whether vllm is using the correct architecture
         if model_info.architecture:
             assert (model_info.architecture in model_config.architectures)
+
+        # Score API is only enabled for num_labels == 1
         assert model_config.hf_config.num_labels == 1
+
+        # Confirm whether vllm uses the correct default_pooling_type, which
+        # relates to whether chunked prefill and prefix caching are enabled
         assert (model_config._model_info.default_pooling_type ==
                 model_info.default_pooling_type)

@@ -330,6 +353,8 @@ def mteb_test_rerank_models(hf_runner,
                                          languages=MTEB_RERANK_LANGS)
         vllm_dtype = model_config.dtype

+    # Accelerate mteb test by setting
+    # SentenceTransformers mteb score to a constant
     if model_info.mteb_score is None:
         st_main_score, st_dtype = mteb_test_rerank_models_hf(
             hf_runner, model_info.name, hf_model_callback)

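The lengthened prompt (`"The chef prepared a delicious meal." * 1000`) deliberately overflows the context window, so a single `embed` call exercises truncation, the embedding dimension, normalization, and NaN handling at once. A minimal sketch of what the new assertion guards, with random tensors standing in for real model output:

import torch

# Stand-in for the output of vllm_model.embed(..., truncate_prompt_tokens=-1).
embeddings = torch.nn.functional.normalize(torch.randn(4, 384), p=2, dim=-1)

# The properties the test checks: no NaNs, and unit L2 norm
# (i.e. the model's normalize step actually ran).
assert not torch.any(torch.isnan(embeddings))
assert torch.allclose(embeddings.norm(p=2, dim=-1), torch.ones(4), atol=1e-5)
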
tests/models/language/pooling/test_bge_reranker_v2_gemma.py
Lines changed: 1 addition & 0 deletions

@@ -14,6 +14,7 @@
 RERANK_MODELS = [
     LASTPoolingRerankModelInfo("BAAI/bge-reranker-v2-gemma",
                                architecture="GemmaForSequenceClassification",
+                               mteb_score=0.33757,
                                hf_overrides={
                                    "architectures":
                                    ["GemmaForSequenceClassification"],

vllm/config/__init__.py
Lines changed: 52 additions & 1 deletion

@@ -745,7 +745,7 @@ def _task_to_convert(task: TaskOption) -> ConvertType:

         self.pooler_config = self._init_pooler_config()

-        self.dtype = _get_and_verify_dtype(
+        self.dtype: torch.dtype = _get_and_verify_dtype(
             self.model,
             self.hf_config,
             self.dtype,
@@ -1751,6 +1751,32 @@ def use_pad_token(self) -> bool:
         # `llm as reranker` models defaults to not using pad_token.
         return getattr(self.hf_config, "use_pad_token", True)

+    @property
+    def head_dtype(self) -> torch.dtype:
+        """
+        "head" refers to the last Linear layer(s) of an LLM,
+        such as the lm_head in a generation model,
+        or the score or classifier in a classification model.
+
+        The default head_dtype based on runner_type.\n
+        - The pooling model defaults to using fp32 head,
+        you can use --hf-overrides '{"head_dtype": "model"}' to disable it.\n
+        - The generate model defaults to not using fp32 head,
+        you can use --hf-overrides '{"head_dtype": "float32"}' to enable it.
+        """
+        head_dtype = _get_head_dtype(config=self.hf_config,
+                                     dtype=self.dtype,
+                                     runner_type=self.runner_type)
+
+        if head_dtype not in current_platform.supported_dtypes:
+            logger.warning_once(
+                "The current platform does not support [%s] head dtype, "
+                "fallback to model dtype [%s].", head_dtype, self.dtype)
+            return self.dtype
+
+        logger.debug_once("head dtype: %s", head_dtype)
+        return head_dtype
+
     def get_and_verify_max_len(self, max_model_len: int):
         # Consider max_model_len in tokenizer_config only when
         # pooling models use absolute position_embedding.
@@ -2893,6 +2919,31 @@ def _get_and_verify_dtype(
     return torch_dtype


+def _get_head_dtype(config: PretrainedConfig, dtype: torch.dtype,
+                    runner_type: str) -> torch.dtype:
+    head_dtype: Optional[Union[str,
+                               torch.dtype]] = getattr(config, "head_dtype",
+                                                       None)
+
+    if head_dtype == "model":
+        return dtype
+    elif isinstance(head_dtype, str):
+        head_dtype = head_dtype.lower()
+        if head_dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
+            raise ValueError(f"Unknown dtype: {head_dtype!r}")
+        return _STR_DTYPE_TO_TORCH_DTYPE[head_dtype]
+    elif isinstance(head_dtype, torch.dtype):
+        return head_dtype
+    elif head_dtype is None:
+        if torch.float32 not in current_platform.supported_dtypes:
+            return dtype
+        if runner_type == "pooling":
+            return torch.float32
+        return dtype
+    else:
+        raise ValueError(f"Unknown dtype: {head_dtype}")
+
+
 def _get_and_verify_max_len(
     hf_config: PretrainedConfig,
     tokenizer_config: Optional[dict],

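For reference, the precedence that `_get_head_dtype` implements can be tried standalone. The sketch below mirrors the diff's logic with plain stand-ins for vLLM internals (`SUPPORTED` for `current_platform.supported_dtypes`, `STR_TO_DTYPE` for `_STR_DTYPE_TO_TORCH_DTYPE`):

from typing import Optional, Union

import torch

STR_TO_DTYPE = {"float32": torch.float32, "float16": torch.float16,
                "bfloat16": torch.bfloat16}
SUPPORTED = {torch.float32, torch.float16, torch.bfloat16}


def resolve_head_dtype(head_dtype: Optional[Union[str, torch.dtype]],
                       model_dtype: torch.dtype,
                       runner_type: str) -> torch.dtype:
    if head_dtype == "model":        # explicit opt-out: follow the model dtype
        return model_dtype
    if isinstance(head_dtype, str):  # explicit dtype name, e.g. "float32"
        return STR_TO_DTYPE[head_dtype.lower()]
    if isinstance(head_dtype, torch.dtype):
        return head_dtype
    # Unset: pooling runners default to fp32 when the platform supports it.
    if torch.float32 in SUPPORTED and runner_type == "pooling":
        return torch.float32
    return model_dtype


assert resolve_head_dtype(None, torch.float16, "pooling") == torch.float32
assert resolve_head_dtype(None, torch.float16, "generate") == torch.float16
assert resolve_head_dtype("model", torch.float16, "pooling") == torch.float16

In practice this means a pooling model served with --dtype float16 gets an fp32 classifier/score head by default, and --hf-overrides '{"head_dtype": "model"}' restores the old behavior.
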
vllm/model_executor/layers/pooler.py
Lines changed: 23 additions & 15 deletions

@@ -5,7 +5,7 @@
 from dataclasses import dataclass
 from enum import IntEnum
 from itertools import groupby
-from typing import Callable, Optional, TypeVar, Union, cast
+from typing import Callable, Optional, TypeVar, Union

 import torch
 import torch.nn as nn
@@ -362,14 +362,13 @@ def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
 class PoolerNormalize(PoolerActivation):

     def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
-        x = F.normalize(pooled_data.float(), p=2, dim=-1)
-        return x.to(pooled_data.dtype)
+        return F.normalize(pooled_data, p=2, dim=-1)


 class PoolerMultiLabelClassify(PoolerActivation):

     def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
-        return F.sigmoid(pooled_data.float()).to(pooled_data.dtype)
+        return F.sigmoid(pooled_data)


 class PoolerClassify(PoolerActivation):
@@ -394,9 +393,9 @@ def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
                            pooled_data.shape[-1])

         if num_labels < 2:
-            return F.sigmoid(pooled_data.float()).to(pooled_data.dtype)
+            return F.sigmoid(pooled_data)

-        return F.softmax(pooled_data.float(), dim=-1).to(pooled_data.dtype)
+        return F.softmax(pooled_data, dim=-1)


 class LambdaPoolerActivation(PoolerActivation):
@@ -432,8 +431,9 @@ def __init__(self) -> None:
         from vllm.model_executor.models.adapters import _load_st_projector

         vllm_config = get_current_vllm_config()
-        self.projector = _load_st_projector(
+        self.projector: Optional[nn.Module] = _load_st_projector(
             vllm_config.model_config) if vllm_config else None
+        self.head_dtype = vllm_config.model_config.head_dtype

     def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
                 pooling_metadata: PoolingMetadata):
@@ -442,16 +442,11 @@ def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
             pooled_data = torch.stack(pooled_data)
         # pooled_data shape: [batchsize, hidden_dimension]

+        pooled_data = pooled_data.to(self.head_dtype)
+
         # Apply ST projector
         if self.projector is not None:
-            projector = cast(nn.Module, self.projector)
-
-            def _proj(x: torch.Tensor) -> torch.Tensor:
-                orig_dtype = x.dtype
-                y = projector(x.to(torch.float32))
-                return y.to(orig_dtype)
-
-            pooled_data = _proj(pooled_data)
+            pooled_data = self.projector(pooled_data)
         # pooled_data shape: [batchsize, embedding_dimension]

         pooling_params = get_pooling_params(pooling_metadata)
@@ -494,8 +489,18 @@ class RewardPoolerHead(PoolerHead):
     def __init__(self) -> None:
         super().__init__(activation=PoolerClassify(static_num_labels=False))

+        from vllm.config import get_current_vllm_config
+        vllm_config = get_current_vllm_config()
+        self.head_dtype = vllm_config.model_config.head_dtype
+
     def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
                 pooling_metadata: PoolingMetadata):
+
+        if isinstance(pooled_data, list):
+            pooled_data = [p.to(self.head_dtype) for p in pooled_data]
+        else:
+            pooled_data = pooled_data.to(self.head_dtype)
+
         pooling_params = get_pooling_params(pooling_metadata)

         # for softmax
@@ -641,6 +646,7 @@ def __init__(
         self.act_fn = act_fn or PoolerClassify()
         self.logit_bias: Optional[
            float] = vllm_config.model_config.pooler_config.logit_bias
+        self.head_dtype = vllm_config.model_config.head_dtype

     def get_supported_tasks(self) -> Set[PoolingTask]:
         return {"classify", "score"}
@@ -655,6 +661,8 @@ def forward(
             pooled_data = torch.stack(pooled_data)
         # pooled_data shape: [batchsize, hidden_size]

+        pooled_data = pooled_data.to(self.head_dtype)
+
         if self.classifier is not None:
             pooled_data = self.classifier(pooled_data)
             # pooled_data shape: [batchsize, num_labels]

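The activation classes above could drop their internal .float()/.to(...) round-trips because each pooler head now casts pooled_data to head_dtype once, at the boundary. A minimal sketch of that cast-once pattern (the class here is illustrative, not vLLM's):

import torch
import torch.nn.functional as F


class ClassifierHead(torch.nn.Module):
    """Illustrative head: cast once to head_dtype, then score and activate."""

    def __init__(self, hidden: int, num_labels: int,
                 head_dtype: torch.dtype = torch.float32):
        super().__init__()
        self.head_dtype = head_dtype
        self.score = torch.nn.Linear(hidden, num_labels, dtype=head_dtype)

    def forward(self, pooled: torch.Tensor) -> torch.Tensor:
        pooled = pooled.to(self.head_dtype)  # single upcast at the boundary
        logits = self.score(pooled)          # linear layer runs in fp32
        return F.softmax(logits, dim=-1)     # no per-activation casting


head = ClassifierHead(hidden=8, num_labels=3)
print(head(torch.randn(2, 8, dtype=torch.float16)).dtype)  # torch.float32
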
vllm/model_executor/models/adapters.py
Lines changed: 4 additions & 6 deletions

@@ -62,15 +62,15 @@ def _load_st_projector(model_config: "ModelConfig") -> Optional[nn.Module]:
             linear = nn.Linear(layer_config.get("in_features", 768),
                                layer_config.get("out_features", 768),
                                bias=layer_config.get("bias", True),
-                               dtype=torch.float32)
+                               dtype=model_config.head_dtype)

             if not _load_dense_weights(linear, folder, model_config):
                 continue

             layers.append(linear)
             if act_name := layer_config.get("activation_function"):
                 layers.append(get_act_fn(act_name))
-        return nn.Sequential(*layers).to(dtype=torch.float32)
+        return nn.Sequential(*layers).to(dtype=model_config.head_dtype)
     except Exception:
         logger.exception("ST projector loading failed")

@@ -105,15 +105,13 @@ def _load_dense_weights(linear: nn.Linear, folder: str,
         if weight_key in state_dict:
             weight_loader = getattr(linear.weight, "weight_loader",
                                     default_weight_loader)
-            weight_loader(linear.weight,
-                          state_dict[weight_key].to(torch.float32))
+            weight_loader(linear.weight, state_dict[weight_key])

             bias_key = weight_key.replace("weight", "bias")
             if linear.bias is not None and bias_key in state_dict:
                 bias_loader = getattr(linear.bias, "weight_loader",
                                       default_weight_loader)
-                bias_loader(linear.bias,
-                            state_dict[bias_key].to(torch.float32))
+                bias_loader(linear.bias, state_dict[bias_key])
         return True
     except Exception:
         logger.exception("Failed to load %s", filename)

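The explicit .to(torch.float32) casts in _load_dense_weights became redundant: the projector's nn.Linear is now constructed in head_dtype, and copying a checkpoint tensor into a parameter converts it to the parameter's dtype anyway. A small demonstration, assuming the loader ends in Tensor.copy_ the way vLLM's default_weight_loader does:

import torch

linear = torch.nn.Linear(4, 2, dtype=torch.float32)    # built in head_dtype
checkpoint_weight = torch.randn(2, 4, dtype=torch.float16)

linear.weight.data.copy_(checkpoint_weight)  # implicit fp16 -> fp32 cast
print(linear.weight.dtype)  # torch.float32
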
vllm/model_executor/models/bert.py
Lines changed: 3 additions & 1 deletion

@@ -562,7 +562,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.bert = BertPoolingModel(vllm_config=vllm_config,
                                      prefix=maybe_prefix(prefix, "bert"),
                                      embedding_class=BertEmbedding)
-        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+        self.classifier = nn.Linear(config.hidden_size,
+                                    config.num_labels,
+                                    dtype=vllm_config.model_config.head_dtype)

         pooler_config = vllm_config.model_config.pooler_config
         assert pooler_config is not None

vllm/model_executor/models/bert_with_rope.py
Lines changed: 8 additions & 8 deletions

@@ -637,14 +637,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.new = GteNewModel(vllm_config=vllm_config,
                                prefix=prefix,
                                add_pooling_layer=True)
-        self.classifier = RowParallelLinear(config.hidden_size,
-                                            config.num_labels,
-                                            input_is_parallel=False,
-                                            bias=True,
-                                            quant_config=quant_config,
-                                            prefix=maybe_prefix(
-                                                prefix, "classifier"),
-                                            return_bias=False)
+        self.classifier = ReplicatedLinear(
+            config.hidden_size,
+            config.num_labels,
+            bias=True,
+            quant_config=quant_config,
+            params_dtype=vllm_config.model_config.head_dtype,
+            prefix=maybe_prefix(prefix, "classifier"),
+            return_bias=False)

         pooler_config = vllm_config.model_config.pooler_config
         assert pooler_config is not None

vllm/model_executor/models/gpt2.py
Lines changed: 6 additions & 4 deletions

@@ -339,7 +339,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         config = vllm_config.model_config.hf_config
         self.transformer = GPT2Model(vllm_config=vllm_config,
                                      prefix=maybe_prefix(prefix, "gpt2"))
-        self.score = nn.Linear(config.n_embd, config.num_labels, bias=False)
+        self.score = nn.Linear(config.n_embd,
+                               config.num_labels,
+                               bias=False,
+                               dtype=vllm_config.model_config.head_dtype)

         pooler_config = vllm_config.model_config.pooler_config
         assert pooler_config is not None
@@ -348,7 +351,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             "encode":
             Pooler.for_encode(pooler_config),
             "classify":
-            Pooler.for_classify(pooler_config, classifier=None),
+            Pooler.for_classify(pooler_config, classifier=self.score),
         })

     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
@@ -367,8 +370,7 @@ def forward(
             position_ids=positions,
             inputs_embeds=inputs_embeds,
             intermediate_tensors=intermediate_tensors)
-        logits = self.score(hidden_states)
-        return logits
+        return hidden_states


 def _add_transformer_prefix(

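Moving the classifier out of forward and into Pooler.for_classify is what lets the head dtype take effect here: scoring used to run in the model dtype before the pooler ever saw the data. A before/after sketch with illustrative layers (bfloat16 standing in for the model dtype):

import torch

torch.manual_seed(0)
hidden = torch.randn(2, 16, dtype=torch.bfloat16)  # backbone output

score_bf16 = torch.nn.Linear(16, 3, bias=False, dtype=torch.bfloat16)
score_fp32 = torch.nn.Linear(16, 3, bias=False, dtype=torch.float32)

# Before: GPT2ForSequenceClassification.forward applied self.score itself,
# so the classifier ran in the model dtype and the pooler got classifier=None.
logits_before = score_bf16(hidden)

# After: forward returns hidden states; the pooler casts to head_dtype first
# and then applies the classifier, so the whole head runs in fp32.
logits_after = score_fp32(hidden.to(torch.float32))

print(logits_before.dtype, logits_after.dtype)  # torch.bfloat16 torch.float32
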