Adds method to read the pooling types from model's files #9506

Open · wants to merge 26 commits into base: main

Commits (26):
e1ff0f7  Adds method to read the pooling types from model's files (flaviabeo, Oct 18, 2024)
5bc9a7d  Adds MEAN pooling type (flaviabeo, Oct 21, 2024)
7119bb3  Make normalize variable return bool value (flaviabeo, Oct 21, 2024)
5b0a9f3  Adds test for model loading with the params (flaviabeo, Oct 21, 2024)
d16eefd  Adds method and attribute for bert sentence_transformer config files (maxdebayser, Oct 22, 2024)
6315c33  Adds other file names for the bert models config (flaviabeo, Oct 24, 2024)
1bcd3e8  fix loading of non-bert models and fix tests (maxdebayser, Oct 24, 2024)
69222e4  Extra check for if the files exists (flaviabeo, Oct 25, 2024)
32ee574  Reverts whitespaces (flaviabeo, Oct 25, 2024)
0b948a4  Adds pooling config as engine CLI arg (flaviabeo, Oct 27, 2024)
c3166f1  add pooling_config to models with a Pooler layer (maxdebayser, Oct 28, 2024)
16bcacd  Fixes tests, splits the pooling config params in type and norm (flaviabeo, Oct 28, 2024)
2cd2450  Moves get_pooling_type logic to ModelConfig (flaviabeo, Oct 28, 2024)
9c32660  Method to treat the pooling name string from file (flaviabeo, Oct 28, 2024)
ae73f4b  Format linting (flaviabeo, Oct 28, 2024)
02195c8  wip (flaviabeo, Oct 28, 2024)
37167ff  test (flaviabeo, Oct 28, 2024)
3d36a8c  Fixes and Exception for not supported pooling types (flaviabeo, Oct 28, 2024)
fbcd540  Fixing lint (flaviabeo, Oct 29, 2024)
c268d89  Fixing linting (flaviabeo, Oct 29, 2024)
5627d2f  Merge branch 'upstream_main' (flaviabeo, Oct 30, 2024)
0da7979  Lint (flaviabeo, Oct 30, 2024)
4531c33  Merge branch 'upstream_main' (flaviabeo, Oct 31, 2024)
8df9d63  Fix merge conflicts (flaviabeo, Oct 31, 2024)
4ac1f20  Review changes requested (flaviabeo, Nov 1, 2024)
ff9705b  Simplify pooler config init (flaviabeo, Nov 1, 2024)
7 changes: 7 additions & 0 deletions tests/engine/test_arg_utils.py
@@ -30,6 +30,13 @@ def test_limit_mm_per_prompt_parser(arg, expected):
     assert args.limit_mm_per_prompt == expected
 
 
+def test_valid_pooling_config():
+    parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
+    args = parser.parse_args(["--pooling-type=MEAN"])
+    engine_args = EngineArgs.from_cli_args(args=args)
+    assert engine_args.pooling_type == 'MEAN'
+
+
 @pytest.mark.parametrize(
     ("arg"),
     [
45 changes: 45 additions & 0 deletions tests/model_executor/test_model_load_with_params.py
@@ -0,0 +1,45 @@
import os

from vllm.model_executor.layers.pooler import PoolingType
from vllm.model_executor.models.bert import BertEmbeddingModel

MAX_MODEL_LEN = 128
MODEL_NAME = os.environ.get("MODEL_NAME", "BAAI/bge-base-en-v1.5")
REVISION = os.environ.get("REVISION", "main")


def test_model_loading_with_params(vllm_runner):
    """
    Test that model, pooling, and tokenizer parameters are loaded from the
    model's config files.
    """
    with vllm_runner(model_name=MODEL_NAME,
                     revision=REVISION,
                     dtype="float16",
                     max_model_len=MAX_MODEL_LEN) as model:
        output = model.encode("Write a short story about a robot that"
                              " dreams for the first time.\n")
[Inline review comment from a Contributor]: Here we need to add a verification to make sure that the pooling layer is configured correctly.

        model_config = model.model.llm_engine.model_config

        model_tokenizer = model.model.llm_engine.tokenizer

        # asserts on the bert model config file
        assert model_config.bert_config["max_seq_length"] == 512
        assert model_config.bert_config["do_lower_case"]

        # asserts on the pooling config files
        assert model_config.pooler_config.pooling_type == PoolingType.CLS.name
        assert model_config.pooler_config.pooling_norm

        # asserts on the tokenizer loaded
        assert model_tokenizer.tokenizer_id == "BAAI/bge-base-en-v1.5"
        assert model_tokenizer.tokenizer_config["do_lower_case"]
        assert model_tokenizer.tokenizer.model_max_length == 512

        model = model.model.llm_engine.model_executor\
            .driver_worker.model_runner.model
        assert isinstance(model, BertEmbeddingModel)
        assert model._pooler.pooling_type == PoolingType.CLS
        assert model._pooler.normalize
        # assert that the model output is non-empty
        assert output
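
For orientation, a minimal end-to-end sketch of the behavior this test covers, using vLLM's public LLM entrypoint; the exact constructor arguments and output shape here are assumptions for illustration, not part of this diff:

from vllm import LLM

# With this PR, the pooling type (CLS for BAAI/bge-base-en-v1.5) and the
# normalization flag come from the repo's Sentence Transformers files
# instead of hard-coded defaults.
llm = LLM(model="BAAI/bge-base-en-v1.5", dtype="float16", max_model_len=128)
outputs = llm.encode("Write a short story about a robot.")
print(len(outputs[0].outputs.embedding))  # 768, the hidden size of bge-base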
65 changes: 65 additions & 0 deletions tests/test_config.py
@@ -1,6 +1,7 @@
 import pytest
 
 from vllm.config import ModelConfig
+from vllm.model_executor.layers.pooler import PoolingType
 
 
 @pytest.mark.parametrize(("model_id", "expected_task"), [
@@ -102,6 +103,70 @@ def test_get_sliding_window():
     assert mistral_model_config.get_sliding_window() == TEST_SLIDING_WINDOW
 
 
+def test_get_pooling_config():
+    model_id = "sentence-transformers/all-MiniLM-L12-v2"
+    minilm_model_config = ModelConfig(
+        model_id,
+        task="auto",
+        tokenizer=model_id,
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        seed=0,
+        dtype="float16",
+        revision=None,
+    )
+
+    minilm_pooling_config = minilm_model_config._init_pooler_config(
+        pooling_type=None,
+        pooling_norm=None,
+        pooling_returned_token_ids=None,
+        pooling_softmax=None,
+        pooling_step_tag_id=None)
+
+    assert minilm_pooling_config.pooling_norm
+    assert minilm_pooling_config.pooling_type == PoolingType.MEAN.name
+
+
+def test_get_pooling_config_from_args():
+    model_id = "sentence-transformers/all-MiniLM-L12-v2"
+    minilm_model_config = ModelConfig(model_id,
+                                      task="auto",
+                                      tokenizer=model_id,
+                                      tokenizer_mode="auto",
+                                      trust_remote_code=False,
+                                      seed=0,
+                                      dtype="float16",
+                                      revision=None)
+
+    minilm_pooling_config = minilm_model_config._init_pooler_config(
+        pooling_type='CLS',
+        pooling_norm=True,
+        pooling_returned_token_ids=None,
+        pooling_softmax=None,
+        pooling_step_tag_id=None)
+
+    assert minilm_pooling_config.pooling_norm
+    assert minilm_pooling_config.pooling_type == PoolingType.CLS.name
+
+
+def test_get_bert_tokenization_sentence_transformer_config():
+    bge_model_config = ModelConfig(
+        model="BAAI/bge-base-en-v1.5",
+        task="auto",
+        tokenizer="BAAI/bge-base-en-v1.5",
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        seed=0,
+        dtype="float16",
+        revision=None,
+    )
+
+    bert_bge_model_config = bge_model_config._get_bert_config()
+
+    assert bert_bge_model_config["max_seq_length"] == 512
+    assert bert_bge_model_config["do_lower_case"]
+
+
 def test_rope_customization():
     TEST_ROPE_SCALING = {"rope_type": "dynamic", "factor": 2.0}
     TEST_ROPE_THETA = 16_000_000.0
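
For reference, the helpers exercised above read standard Sentence Transformers artifacts from the model repository. A sketch of the file shapes involved; the JSON below is illustrative of the usual layout, not copied from this PR:

import json

# 1_Pooling/config.json: typically exactly one pooling_mode_* flag is true;
# get_pooling_config maps it to a PoolingType name such as MEAN or CLS.
pooling_cfg = json.loads("""
{
    "word_embedding_dimension": 384,
    "pooling_mode_cls_token": false,
    "pooling_mode_mean_tokens": true,
    "pooling_mode_max_tokens": false
}
""")

# sentence_bert_config.json: tokenization settings read by
# get_sentence_transformer_tokenizer_config (ModelConfig._get_bert_config).
bert_cfg = json.loads("""
{
    "max_seq_length": 512,
    "do_lower_case": true
}
""")

# A simplified version of the flag-to-name mapping (assumed, illustration only):
pooling_type = "MEAN" if pooling_cfg["pooling_mode_mean_tokens"] else "CLS"
assert pooling_type == "MEAN"
assert bert_cfg["max_seq_length"] == 512

This is why all-MiniLM-L12-v2 resolves to MEAN pooling while BAAI/bge-base-en-v1.5 resolves to CLS in these tests.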
34 changes: 28 additions & 6 deletions vllm/config.py
@@ -13,9 +13,10 @@
 from vllm.model_executor.models import ModelRegistry
 from vllm.platforms import current_platform
 from vllm.tracing import is_otel_available, otel_import_error_traceback
-from vllm.transformers_utils.config import (ConfigFormat, get_config,
-                                            get_hf_image_processor_config,
-                                            get_hf_text_config)
+from vllm.transformers_utils.config import (
+    ConfigFormat, get_config, get_hf_image_processor_config,
+    get_hf_text_config, get_pooling_config,
+    get_sentence_transformer_tokenizer_config)
 from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
                         print_warning_once)
@@ -186,6 +187,7 @@ def __init__(
                                     code_revision, rope_scaling, rope_theta,
                                     config_format)
         self.hf_text_config = get_hf_text_config(self.hf_config)
+        self.bert_config = self._get_bert_config()
         self.hf_image_processor_config = get_hf_image_processor_config(
             self.model, revision)
         self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
@@ -218,7 +220,8 @@ def __init__(
             max_model_len=max_model_len,
             disable_sliding_window=self.disable_sliding_window,
             sliding_window_len=self.get_hf_config_sliding_window(),
-            spec_target_max_model_len=spec_target_max_model_len)
+            spec_target_max_model_len=spec_target_max_model_len,
+            bert_config=self.bert_config)
         self.served_model_name = get_served_model_name(model,
                                                        served_model_name)
         self.multimodal_config = self._init_multimodal_config(
@@ -262,6 +265,10 @@ def _init_multimodal_config(
 
         return None
 
+    def _get_bert_config(self):
+        return get_sentence_transformer_tokenizer_config(
+            self.model, self.revision)
+
     def _init_pooler_config(
         self,
         pooling_type: Optional[str] = None,
@@ -271,9 +278,20 @@ def _init_pooler_config(
         pooling_returned_token_ids: Optional[List[int]] = None
     ) -> Optional["PoolerConfig"]:
         if self.task == "embedding":
+            pooling_type_ = pooling_type
+            normalize_ = pooling_norm
+            pooling_config = get_pooling_config(self.model, self.revision)
+            if pooling_config is not None:
+                pooling_type_ = pooling_config["pooling_type"]
+                normalize_ = pooling_config["normalize"]
+            # override if user specifies pooling_type and/or pooling_norm
+            if pooling_type is not None:
+                pooling_type_ = pooling_type
+            if pooling_norm is not None:
+                normalize_ = pooling_norm
             return PoolerConfig(
-                pooling_type=pooling_type,
-                pooling_norm=pooling_norm,
+                pooling_type=pooling_type_,
+                pooling_norm=normalize_,
                 pooling_softmax=pooling_softmax,
                 pooling_step_tag_id=pooling_step_tag_id,
                 pooling_returned_token_ids=pooling_returned_token_ids)
@@ -1764,6 +1782,7 @@ def _get_and_verify_max_len(
     disable_sliding_window: bool,
     sliding_window_len: Optional[Union[int, List[Optional[int]]]],
     spec_target_max_model_len: Optional[int] = None,
+    bert_config: Optional[Any] = None,
 ) -> int:
     """Get and verify the model's maximum length."""
     derived_max_model_len = float("inf")
@@ -1846,6 +1865,9 @@
                 "original_max_position_embeddings"]
         derived_max_model_len *= scaling_factor
 
+    if bert_config and "max_seq_length" in bert_config:
+        derived_max_model_len = bert_config["max_seq_length"]
+
     # If the user specified a max length, make sure it is smaller than the
     # derived length from the HF model config.
     if max_model_len is None:
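
The precedence rule in _init_pooler_config above is: values read from the model's files become the defaults, and explicit user arguments win. A standalone sketch of that logic; the dict keys follow the hunk above, while the function name is illustrative:

from typing import Optional, Tuple

def resolve_pooling(
        file_cfg: Optional[dict], cli_type: Optional[str],
        cli_norm: Optional[bool]) -> Tuple[Optional[str], Optional[bool]]:
    # Start from the CLI values (possibly None).
    pooling_type, normalize = cli_type, cli_norm
    if file_cfg is not None:
        # The model's files provide the defaults...
        pooling_type = file_cfg["pooling_type"]
        normalize = file_cfg["normalize"]
    # ...but an explicit user setting overrides them.
    if cli_type is not None:
        pooling_type = cli_type
    if cli_norm is not None:
        normalize = cli_norm
    return pooling_type, normalize

assert resolve_pooling({"pooling_type": "MEAN", "normalize": True},
                       None, None) == ("MEAN", True)
assert resolve_pooling({"pooling_type": "MEAN", "normalize": True},
                       "CLS", None) == ("CLS", True)

The same file-first approach feeds _get_and_verify_max_len: a max_seq_length found in the Sentence Transformers config (512 for BAAI/bge-base-en-v1.5, per the tests) caps the derived maximum model length.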
3 changes: 2 additions & 1 deletion vllm/engine/arg_utils.py
@@ -15,6 +15,7 @@
                         SpeculativeConfig, TaskOption, TokenizerPoolConfig)
 from vllm.executor.executor_base import ExecutorBase
 from vllm.logger import init_logger
+from vllm.model_executor.layers.pooler import PoolingType
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 from vllm.transformers_utils.config import (
     maybe_register_config_serialize_by_value)
@@ -850,7 +851,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
 
         parser.add_argument(
             '--pooling-type',
-            choices=['LAST', 'ALL', 'CLS', 'STEP'],
+            choices=[pt.name for pt in PoolingType],
            default=None,
             help='Used to configure the pooling method in the embedding model.'
         )
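
Deriving choices from the enum means a new PoolingType member, like MEAN in this PR, surfaces in --pooling-type automatically. A self-contained sketch of the pattern, with plain argparse standing in for vLLM's FlexibleArgumentParser:

import argparse
from enum import IntEnum


class PoolingType(IntEnum):  # mirrors vllm/model_executor/layers/pooler.py
    LAST = 0
    ALL = 1
    CLS = 2
    STEP = 3
    MEAN = 4


parser = argparse.ArgumentParser()
parser.add_argument("--pooling-type",
                    choices=[pt.name for pt in PoolingType],
                    default=None,
                    help="Pooling method for the embedding model.")

args = parser.parse_args(["--pooling-type=MEAN"])
assert args.pooling_type == "MEAN"
# An unknown value such as --pooling-type=FOO fails parsing with an error.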
14 changes: 13 additions & 1 deletion vllm/model_executor/layers/pooler.py
@@ -16,6 +16,7 @@ class PoolingType(IntEnum):
     ALL = 1
     CLS = 2
     STEP = 3
+    MEAN = 4
 
 
 class Pooler(nn.Module):
@@ -27,7 +28,7 @@ class Pooler(nn.Module):
     3. Returns structured results as `PoolerOutput`.
 
     Attributes:
-        pooling_type: The type of pooling to use (LAST, ALL, CLS).
+        pooling_type: The type of pooling to use (LAST, ALL, CLS, STEP, MEAN).
         normalize: Whether to normalize the pooled data.
     """
 
@@ -97,6 +98,17 @@ def forward(
             for prompt_len in prompt_lens:
                 pooled_data.append(hidden_states[offset:offset + prompt_len])
                 offset += prompt_len
+        elif self.pooling_type == PoolingType.MEAN:
+            # Calculate mean pooling
+            cumsum = torch.cumsum(hidden_states, dim=0)
+            start_indices = torch.cat([
+                torch.tensor([0], device=hidden_states.device),
+                torch.cumsum(prompt_lens[:-1], dim=0)
+            ])
+            end_indices = torch.cumsum(prompt_lens, dim=0)
+            pooled_data = (
+                cumsum[end_indices - 1] - cumsum[start_indices] +
+                hidden_states[start_indices]) / prompt_lens.unsqueeze(1)
         elif self.pooling_type == PoolingType.STEP:
             if self.returned_token_ids is not None and len(
                     self.returned_token_ids) > 0:
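
The MEAN branch replaces a per-prompt Python loop with one prefix sum: for a prompt occupying rows [s, e) of hidden_states, cumsum[e - 1] - cumsum[s] + hidden_states[s] equals the row-sum over [s, e). A standalone check of that identity against a naive mean, assuming hidden_states of shape [total_tokens, hidden] and a 1-D integer prompt_lens; this mirrors, but does not reuse, the vLLM code above:

import torch


def mean_pool_cumsum(hidden_states: torch.Tensor,
                     prompt_lens: torch.Tensor) -> torch.Tensor:
    # Prefix sums over the flattened token dimension.
    cumsum = torch.cumsum(hidden_states, dim=0)
    start = torch.cat(
        [torch.tensor([0]), torch.cumsum(prompt_lens[:-1], dim=0)])
    end = torch.cumsum(prompt_lens, dim=0)
    # Row-sum over [start, end) for each prompt, divided by its length.
    return (cumsum[end - 1] - cumsum[start] +
            hidden_states[start]) / prompt_lens.unsqueeze(1)


hidden = torch.randn(7, 4)
lens = torch.tensor([3, 4])
naive = torch.stack([hidden[:3].mean(dim=0), hidden[3:].mean(dim=0)])
assert torch.allclose(mean_pool_cumsum(hidden, lens), naive, atol=1e-5)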