Commit 3d443c6

Support bge-m3 sparse embeddings

Now with the pooling task framework

Signed-off-by: Max de Bayser <mbayser@br.ibm.com>

1 parent 67c153b

File tree

6 files changed: +158 −7 lines changed

vllm/entrypoints/openai/protocol.py

Lines changed: 19 additions & 2 deletions

@@ -1366,8 +1366,25 @@ def to_pooling_params(self):
 
 EmbeddingRequest = Union[EmbeddingCompletionRequest, EmbeddingChatRequest]
 
-PoolingCompletionRequest = EmbeddingCompletionRequest
-PoolingChatRequest = EmbeddingChatRequest
+
+class PoolingCompletionRequest(EmbeddingCompletionRequest):
+    task: Optional[str] = None
+
+    def to_pooling_params(self):
+        return PoolingParams(dimensions=self.dimensions,
+                             normalize=self.normalize,
+                             task=self.task)
+
+
+class PoolingChatRequest(EmbeddingChatRequest):
+    task: Optional[str] = None
+
+    def to_pooling_params(self):
+        return PoolingParams(dimensions=self.dimensions,
+                             normalize=self.normalize,
+                             task=self.task)
+
+
 PoolingRequest = Union[PoolingCompletionRequest, PoolingChatRequest]
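With this change a pooling request can carry an explicit task. A minimal client-side sketch, assuming a server started with BAAI/bge-m3 and the /pooling endpoint that create_pooling serves; the URL, port, and response shape are assumptions, not part of this diff:

    # Hypothetical request against a local vLLM server (address assumed).
    import requests

    resp = requests.post(
        "http://localhost:8000/pooling",
        json={
            "model": "BAAI/bge-m3",
            "input": "What is BGE M3?",
            "task": "embed-sparse",  # omit to keep the legacy "encode" behavior
        },
    )
    resp.raise_for_status()
    print(resp.json())  # per-token weights, with BOS/EOS stripped by the pooler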

vllm/entrypoints/openai/serving_pooling.py

Lines changed: 2 additions & 1 deletion

@@ -140,7 +140,8 @@ async def create_pooling(
         pooling_params = request.to_pooling_params()
 
         try:
-            pooling_params.verify("encode", self.model_config)
+            task = request.task if request.task is not None else "encode"
+            pooling_params.verify(task, self.model_config)
         except ValueError as e:
             return self.create_error_response(str(e))
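The fallback keeps backwards compatibility: requests that omit task are still verified as "encode", while requests that set it are verified against the requested task. A standalone sketch of that rule:

    from typing import Optional

    # Mirrors the fallback above: None means the legacy "encode" behavior.
    def resolve_task(request_task: Optional[str]) -> str:
        return request_task if request_task is not None else "encode"

    assert resolve_task(None) == "encode"                  # legacy clients
    assert resolve_task("embed-sparse") == "embed-sparse"  # new BGE-M3 path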

vllm/model_executor/models/registry.py

Lines changed: 1 addition & 0 deletions

@@ -173,6 +173,7 @@
     "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
     "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
     "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
+    "BgeM3EmbeddingModel": ("roberta", "BgeM3EmbeddingModel"),
     # [Multimodal]
     "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
     "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
vllm/model_executor/models/roberta.py

Lines changed: 133 additions & 2 deletions

@@ -1,26 +1,33 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+import itertools
 from collections.abc import Iterable
 from typing import Optional, Union
 
 import torch
 from torch import nn
 from transformers import RobertaConfig
 
-from vllm.config import VllmConfig
+from vllm.config import PoolerConfig, VllmConfig
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.pooler import (ClassifierPooler, CLSPool,
-                                               DispatchPooler, Pooler)
+                                               DispatchPooler, Pooler,
+                                               PoolerOutput, PoolingMetadata,
+                                               PoolingParamsUpdate,
+                                               PoolingTask, build_output)
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
+from vllm.model_executor.model_loader.default_loader import DefaultModelLoader
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.bert import (TOKEN_TYPE_SHIFT,
                                              BertEmbeddingModel, BertModel,
                                              _decode_token_type_ids,
                                              _encode_token_type_ids)
 from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper,
                                               maybe_prefix)
 from vllm.sequence import IntermediateTensors
+from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata
 
 from .bert_with_rope import BertWithRope, JinaRobertaModel
 from .interfaces import SupportsCrossEncoding, default_pooling_type
@@ -150,6 +157,130 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
         return loader.load_weights(weights_list, mapper=mapper)
 
 
+class M3SparsePooler(Pooler):
+    """A pooler that implements BGE-M3 sparse pooling.
+
+    This layer applies sparse_linear followed by a ReLU to the hidden
+    states to obtain one non-negative weight per token, then strips
+    the special BOS/EOS positions.
+
+    Attributes:
+        sparse_linear: The linear module applied to the hidden states
+            to obtain the token weights.
+        bos_token_id and eos_token_id: The special tokens inserted by
+            the tokenizer. These are removed for sparse embeddings.
+    """
+
+    def __init__(self, sparse_linear: nn.Module, bos_token_id: int,
+                 eos_token_id: int) -> None:
+        super().__init__()
+        self.sparse_linear = sparse_linear
+        self.bos_token_id = bos_token_id
+        self.eos_token_id = eos_token_id
+
+    def get_supported_tasks(self) -> set[PoolingTask]:
+        return {"embed-sparse"}
+
+    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
+        return PoolingParamsUpdate(requires_token_ids=True)
+
+    def forward(
+        self,
+        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
+        pooling_metadata: PoolingMetadata,
+    ) -> PoolerOutput:
+
+        assert isinstance(pooling_metadata, V1PoolingMetadata), \
+            "BGE-M3 sparse embeddings are only supported with V1"
+        assert isinstance(hidden_states, list)
+
+        pooled_outputs = []
+
+        for i, hidden_state in enumerate(hidden_states):
+            pooled_data = torch.squeeze(torch.relu(
+                self.sparse_linear(hidden_state)),
+                                        dim=0)
+            token_ids = pooling_metadata.prompt_token_ids[
+                i, :pooling_metadata.prompt_lens[i]]
+            if token_ids[0] == self.bos_token_id:
+                pooled_data = pooled_data[1:]
+            if token_ids[-1] == self.eos_token_id:
+                pooled_data = pooled_data[:-1]
+            pooled_outputs.append(pooled_data)
+
+        return PoolerOutput(outputs=build_output(pooled_outputs))
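Each output is a vector holding one non-negative weight per surviving token; turning it into a lexical token-to-weight mapping is left to the caller. A hedged sketch of such post-processing, following the BGE-M3 paper's convention of keeping the maximum weight when a token repeats (this helper is illustrative, not part of the commit):

    import torch

    def to_lexical_weights(token_ids: torch.Tensor,
                           weights: torch.Tensor) -> dict:
        # Collapse per-position weights into {token_id: weight},
        # keeping the max weight for repeated tokens.
        lexical = {}
        for tid, w in zip(token_ids.tolist(), weights.squeeze(-1).tolist()):
            if w > 0.0:
                lexical[tid] = max(w, lexical.get(tid, 0.0))
        return lexical

    # Toy usage: token 7 appears twice; its larger weight wins.
    ids = torch.tensor([7, 42, 7])
    ws = torch.tensor([[0.1], [0.5], [0.3]], dtype=torch.float64)
    print(to_lexical_weights(ids, ws))  # {7: 0.3, 42: 0.5}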
+
+
+def filter_secondary_weights(
+    all_weights: Iterable[tuple[str, torch.Tensor]],
+    secondary_weights: list[str],
+) -> tuple[Iterable[tuple[str, torch.Tensor]],
+           Iterable[tuple[str, torch.Tensor]]]:
+    all_weights1, all_weights2 = itertools.tee(all_weights)
+
+    def filtered(n):
+        return any(n.startswith(f) for f in secondary_weights)
+
+    return ((n, w) for n, w in all_weights1 if filtered(n)), \
+           ((n, w) for n, w in all_weights2 if not filtered(n))
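Because the incoming weights arrive as a single-pass iterator, itertools.tee lets both filtered views be produced without materializing the whole stream. The same partitioning idea in a standalone, runnable form:

    import itertools

    def partition(items, prefixes):
        # Split an iterator of (name, value) pairs into (matching, rest).
        a, b = itertools.tee(items)

        def hit(name):
            return any(name.startswith(p) for p in prefixes)

        return ((n, v) for n, v in a if hit(n)), \
               ((n, v) for n, v in b if not hit(n))

    weights = iter([("sparse_linear.weight", 1), ("encoder.layer.0.w", 2)])
    secondary, primary = partition(weights, ["sparse_linear."])
    print(list(secondary))  # [('sparse_linear.weight', 1)]
    print(list(primary))    # [('encoder.layer.0.w', 2)]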
+
+
+class BgeM3EmbeddingModel(RobertaEmbeddingModel):
+    """A model that extends RobertaEmbeddingModel with sparse embeddings.
+
+    This class supports loading an additional sparse_linear.pt file
+    to create sparse embeddings as described in
+    https://arxiv.org/abs/2402.03216
+    """
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+
+        self.hidden_size = vllm_config.model_config.hf_config.hidden_size
+
+        self.bos_token_id = vllm_config.model_config.hf_config.bos_token_id
+        self.eos_token_id = vllm_config.model_config.hf_config.eos_token_id
+
+        super().__init__(vllm_config=vllm_config, prefix=prefix)
+        self.secondary_weight_prefix = "sparse_linear."
+
+        self.secondary_weights = [
+            DefaultModelLoader.Source(
+                model_or_path=vllm_config.model_config.model,
+                revision=None,
+                prefix=self.secondary_weight_prefix,
+                allow_patterns_overrides=["sparse_linear.pt"])
+        ]
+
+    def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
+        self.sparse_linear = nn.Linear(self.hidden_size, 1)
+        return DispatchPooler({
+            "encode":
+            Pooler.for_encode(pooler_config),
+            "embed":
+            Pooler.for_embed(pooler_config),
+            "embed-sparse":
+            M3SparsePooler(self.sparse_linear, self.bos_token_id,
+                           self.eos_token_id),
+        })
+
+    def load_weights(self, all_weights: Iterable[tuple[str, torch.Tensor]]):
+        secondary, weights = filter_secondary_weights(
+            all_weights, [self.secondary_weight_prefix])
+
+        super().load_weights(weights)
+
+        params_dict = dict(self.named_parameters())
+
+        for name, loaded_weight in secondary:
+            if name.startswith(self.secondary_weight_prefix):
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+
 @default_pooling_type("CLS")
 class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
     """A model that uses Roberta to provide embedding functionalities.
vllm/pooling_params.py

Lines changed: 2 additions & 1 deletion

@@ -62,6 +62,7 @@ def all_parameters(self) -> list[str]:
     def valid_parameters(self):
         return {
             "embed": ["dimensions", "normalize"],
+            "embed-sparse": ["dimensions", "normalize"],
             "classify": ["activation"],
             "score": ["activation"],
             "encode": ["softmax", "step_tag_id", "returned_token_ids"],

@@ -111,7 +112,7 @@ def _merge_default_parameters(self,
                 setattr(self, k, getattr(pooler_config, k))
 
     def _set_default_parameters(self, model_config: Optional["ModelConfig"]):
-        if self.task == "embed":
+        if self.task in ["embed", "embed-sparse"]:
             if self.normalize is None:
                 self.normalize = True
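"embed-sparse" accepts the same parameters as "embed" and inherits its defaulting, so normalize falls back to True when unset. A simplified standalone sketch of that rule (the real method also consults model_config):

    from typing import Optional

    def default_normalize(task: str,
                          normalize: Optional[bool]) -> Optional[bool]:
        if task in ["embed", "embed-sparse"] and normalize is None:
            return True
        return normalize

    assert default_normalize("embed-sparse", None) is True
    assert default_normalize("embed-sparse", False) is False
    assert default_normalize("encode", None) is None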

vllm/tasks.py

Lines changed: 1 addition & 1 deletion

@@ -5,7 +5,7 @@
 GenerationTask = Literal["generate", "transcription"]
 GENERATION_TASKS = get_args(GenerationTask)
 
-PoolingTask = Literal["encode", "embed", "classify", "score"]
+PoolingTask = Literal["encode", "embed", "embed-sparse", "classify", "score"]
 POOLING_TASKS = get_args(PoolingTask)
 
 SupportedTask = Literal[GenerationTask, PoolingTask]
