
Commit f54f851

[Model][2/N] Improve all pooling task | Support multi-vector retrieval (#25370)
Signed-off-by: wang.yuqi <noooop@126.com>
1 parent d4d1a60 commit f54f851


41 files changed, +786 -399 lines changed
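Taken together, the diffs below introduce an explicit `pooling_task` argument on `LLM.encode` ("token_embed" for per-token, multi-vector embeddings; "token_classify" for per-token classification) and rename the `softmax` flag of `PoolingParams` to `activation`. The following is a rough sketch of that call surface, inferred only from the changed examples and tests in this commit, not from separate documentation:

```python
# Sketch of the API surface this commit exercises; names are taken from the
# changed examples and tests below, not from separate documentation.
from vllm import LLM, PoolingParams

llm = LLM(model="BAAI/bge-m3", runner="pooling", enforce_eager=True)

# Per-token ("multi-vector") embeddings: one (num_tokens, hidden_size)
# tensor per prompt, instead of a single pooled vector.
outputs = llm.encode(["The capital of France is"], pooling_task="token_embed")
print(outputs[0].outputs.data.shape)

# PoolingParams now takes `activation` where it previously took `softmax`
# (see the prithvi example and test_reward.py changes below).
params = PoolingParams(activation=False)
```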

examples/offline_inference/pooling/README.md

Lines changed: 6 additions & 0 deletions
@@ -26,6 +26,12 @@ python examples/offline_inference/pooling/embed_jina_embeddings_v3.py
 python examples/offline_inference/pooling/embed_matryoshka_fy.py
 ```

+## Multi vector retrieval usage
+
+```bash
+python examples/offline_inference/pooling/multi_vector_retrieval.py
+```
+
 ## Named Entity Recognition (NER) usage

 ```bash
examples/offline_inference/pooling/multi_vector_retrieval.py

Lines changed: 56 additions & 0 deletions
@@ -0,0 +1,56 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from argparse import Namespace

from vllm import LLM, EngineArgs
from vllm.utils import FlexibleArgumentParser


def parse_args():
    parser = FlexibleArgumentParser()
    parser = EngineArgs.add_cli_args(parser)
    # Set example specific arguments
    parser.set_defaults(
        model="BAAI/bge-m3",
        runner="pooling",
        enforce_eager=True,
    )
    return parser.parse_args()


def main(args: Namespace):
    # Sample prompts.
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    # Create an LLM.
    # You should pass runner="pooling" for embedding models
    llm = LLM(**vars(args))

    # Generate embedding. The output is a list of EmbeddingRequestOutputs.
    outputs = llm.embed(prompts)

    # Print the outputs.
    print("\nGenerated Outputs:\n" + "-" * 60)
    for prompt, output in zip(prompts, outputs):
        embeds = output.outputs.embedding
        print(len(embeds))

    # Generate embedding for each token. The output is a list of PoolingRequestOutput.
    outputs = llm.encode(prompts, pooling_task="token_embed")

    # Print the outputs.
    print("\nGenerated Outputs:\n" + "-" * 60)
    for prompt, output in zip(prompts, outputs):
        multi_vector = output.outputs.data
        print(multi_vector.shape)


if __name__ == "__main__":
    args = parse_args()
    main(args)
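Token-level embeddings are what enable late-interaction ("multi-vector") retrieval. As a purely illustrative follow-up to the example above, the sketch below ranks documents against a query with a MaxSim score, the relevance function used by ColBERT-style retrievers such as BGE-M3; the scoring code is not part of this commit, and it assumes `output.outputs.data` is a 2-D tensor of shape (num_tokens, hidden_size), as printed above.

```python
# Illustrative MaxSim (late-interaction) scoring over the per-token embeddings
# returned by pooling_task="token_embed". Not part of this commit.
import torch
import torch.nn.functional as F


def maxsim(query_vecs: torch.Tensor, doc_vecs: torch.Tensor) -> float:
    q = F.normalize(query_vecs, dim=-1)
    d = F.normalize(doc_vecs, dim=-1)
    sim = q @ d.T  # cosine similarity, shape (query_tokens, doc_tokens)
    # For each query token keep its best-matching document token, then sum.
    return sim.max(dim=1).values.sum().item()


# e.g. treat the first prompt as the query and rank the rest:
# query, *docs = [o.outputs.data for o in outputs]
# ranked = sorted(range(len(docs)), key=lambda i: maxsim(query, docs[i]), reverse=True)
```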

examples/offline_inference/prithvi_geospatial_mae_io_processor.py

Lines changed: 1 addition & 1 deletion
@@ -40,7 +40,7 @@ def main():
         model_impl="terratorch",
     )

-    pooling_params = PoolingParams(task="encode", softmax=False)
+    pooling_params = PoolingParams(task="token_classify", activation=False)
     pooler_output = llm.encode(
         img_prompt,
         pooling_params=pooling_params,

examples/online_serving/pooling/README.md

Lines changed: 6 additions & 0 deletions
@@ -18,6 +18,12 @@ python examples/online_serving/pooling/embedding_embed_dtype_client.py
 python examples/online_serving/pooling/jinaai_rerank_client.py
 ```

+## Multi vector retrieval usage
+
+```bash
+python examples/online_serving/pooling/multi_vector_retrieval_client.py
+```
+
 ## Named Entity Recognition (NER) usage

 ```bash
examples/online_serving/pooling/multi_vector_retrieval_client.py

Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""
Example online usage of Pooling API for multi vector retrieval.

Run `vllm serve <model> --runner pooling`
to start up the server in vLLM. e.g.

vllm serve BAAI/bge-m3
"""

import argparse

import requests
import torch


def post_http_request(prompt: dict, api_url: str) -> requests.Response:
    headers = {"User-Agent": "Test Client"}
    response = requests.post(api_url, headers=headers, json=prompt)
    return response


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--model", type=str, default="BAAI/bge-m3")

    return parser.parse_args()


def main(args):
    api_url = f"http://{args.host}:{args.port}/pooling"
    model_name = args.model

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    prompt = {"model": model_name, "input": prompts}

    pooling_response = post_http_request(prompt=prompt, api_url=api_url)
    for output in pooling_response.json()["data"]:
        multi_vector = torch.tensor(output["data"])
        print(multi_vector.shape)


if __name__ == "__main__":
    args = parse_args()
    main(args)
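For contrast with the token-level /pooling output above, the same server is expected to also expose the OpenAI-compatible /v1/embeddings route, which returns one pooled vector per input. That route is not touched by this commit; the payload and response shapes below follow the OpenAI embeddings API and are an assumption here, not something this diff demonstrates.

```python
# Assumed OpenAI-compatible embeddings route on the same vLLM server;
# illustration only, not part of this commit.
import requests

resp = requests.post(
    "http://localhost:8000/v1/embeddings",
    headers={"User-Agent": "Test Client"},
    json={"model": "BAAI/bge-m3", "input": ["The capital of France is"]},
)
for item in resp.json()["data"]:
    print(len(item["embedding"]))  # one dense vector, vs. (num_tokens, dim) above
```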

tests/conftest.py

Lines changed: 6 additions & 2 deletions
@@ -1011,8 +1011,12 @@ def embed(
         req_outputs = self.llm.embed(inputs, *args, **kwargs)
         return [req_output.outputs.embedding for req_output in req_outputs]

-    def encode(self, prompts: list[str]) -> list[list[float]]:
-        req_outputs = self.llm.encode(prompts)
+    def token_embed(self, prompts: list[str]) -> list[list[float]]:
+        req_outputs = self.llm.encode(prompts, pooling_task="token_embed")
+        return [req_output.outputs.data for req_output in req_outputs]
+
+    def token_classify(self, prompts: list[str]) -> list[list[float]]:
+        req_outputs = self.llm.encode(prompts, pooling_task="token_classify")
         return [req_output.outputs.data for req_output in req_outputs]

     def reward(self, prompts: list[str]) -> list[list[float]]:
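A hypothetical illustration (not part of this commit) of how a model test might call the renamed runner helpers; the `vllm_runner` fixture pattern, the model name, and the shape check are placeholders rather than code from this diff.

```python
# Hypothetical usage sketch of the new test-runner helpers; fixture name,
# model, and assertions are illustrative assumptions, not from this diff.
def test_token_embed_shapes(vllm_runner, example_prompts):
    with vllm_runner("BAAI/bge-m3", runner="pooling") as vllm_model:
        per_token = vllm_model.token_embed(example_prompts)
    # one (num_tokens, hidden_size) array per prompt
    assert len(per_token) == len(example_prompts)
```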

tests/entrypoints/pooling/llm/test_classify.py

Lines changed: 1 addition & 1 deletion
@@ -63,7 +63,7 @@ def test_encode_api(llm: LLM):
     # chunked prefill does not support all pooling
     err_msg = "pooling_task must be one of.+"
     with pytest.raises(ValueError, match=err_msg):
-        llm.encode(prompts, use_tqdm=False)
+        llm.encode(prompts, pooling_task="token_classify", use_tqdm=False)


 def test_score_api(llm: LLM):

tests/entrypoints/pooling/llm/test_embedding.py

Lines changed: 7 additions & 0 deletions
@@ -35,6 +35,13 @@ def llm():
     cleanup_dist_env_and_memory()


+@pytest.mark.skip_global_cleanup
+def test_encode_api(llm: LLM):
+    outputs = llm.encode(prompts, pooling_task="token_embed", use_tqdm=False)
+    multi_vector = outputs[0].outputs.data
+    assert multi_vector.shape == (11, 384)
+
+
 def test_pooling_params(llm: LLM):
     def get_outputs(normalize):
         outputs = llm.embed(

tests/entrypoints/pooling/llm/test_encode.py

Lines changed: 8 additions & 4 deletions
@@ -57,20 +57,24 @@ def test_multiple_pooling_params(llm: LLM):
     ]

     # Multiple PoolingParams should be matched with each prompt
-    outputs = llm.encode(PROMPTS, pooling_params=pooling_params)
+    outputs = llm.encode(PROMPTS, pooling_params=pooling_params, pooling_task="embed")
     assert len(PROMPTS) == len(outputs)

     # Exception raised, if the size of params does not match the size of prompts
     with pytest.raises(ValueError):
-        outputs = llm.encode(PROMPTS, pooling_params=pooling_params[:3])
+        outputs = llm.encode(
+            PROMPTS, pooling_params=pooling_params[:3], pooling_task="embed"
+        )

     # Single PoolingParams should be applied to every prompt
     single_pooling_params = PoolingParams()
-    outputs = llm.encode(PROMPTS, pooling_params=single_pooling_params)
+    outputs = llm.encode(
+        PROMPTS, pooling_params=single_pooling_params, pooling_task="embed"
+    )
     assert len(PROMPTS) == len(outputs)

     # pooling_params is None, default params should be applied
-    outputs = llm.encode(PROMPTS, pooling_params=None)
+    outputs = llm.encode(PROMPTS, pooling_params=None, pooling_task="embed")
     assert len(PROMPTS) == len(outputs)
tests/entrypoints/pooling/llm/test_reward.py

Lines changed: 12 additions & 11 deletions
@@ -36,22 +36,23 @@ def llm():
     cleanup_dist_env_and_memory()


-@pytest.mark.skip_global_cleanup
 def test_pooling_params(llm: LLM):
-    def get_outputs(softmax):
+    def get_outputs(activation):
         outputs = llm.reward(
-            prompts, pooling_params=PoolingParams(softmax=softmax), use_tqdm=False
+            prompts, pooling_params=PoolingParams(activation=activation), use_tqdm=False
         )
         return torch.cat([x.outputs.data for x in outputs])

-    default = get_outputs(softmax=None)
-    w_softmax = get_outputs(softmax=True)
-    wo_softmax = get_outputs(softmax=False)
+    default = get_outputs(activation=None)
+    w_activation = get_outputs(activation=True)
+    wo_activation = get_outputs(activation=False)

-    assert torch.allclose(default, w_softmax, atol=1e-2), "Default should use softmax."
-    assert not torch.allclose(w_softmax, wo_softmax, atol=1e-2), (
-        "wo_softmax should not use softmax."
+    assert torch.allclose(default, w_activation, atol=1e-2), (
+        "Default should use activation."
     )
-    assert torch.allclose(softmax(wo_softmax), w_softmax, atol=1e-2), (
-        "w_softmax should be close to softmax(wo_softmax)."
+    assert not torch.allclose(w_activation, wo_activation, atol=1e-2), (
+        "wo_activation should not use activation."
+    )
+    assert torch.allclose(softmax(wo_activation), w_activation, atol=1e-2), (
+        "w_activation should be close to activation(wo_activation)."
     )