# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

+import itertools
from collections.abc import Iterable
from typing import Optional, Union

import torch
from torch import nn
from transformers import RobertaConfig

-from vllm.config import VllmConfig
+from vllm.config import PoolerConfig, VllmConfig
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.pooler import (ClassifierPooler, CLSPool,
-                                               DispatchPooler, Pooler)
+                                               DispatchPooler, Pooler,
+                                               PoolerOutput, PoolingMetadata,
+                                               PoolingParamsUpdate,
+                                               PoolingTask, build_output)
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
+from vllm.model_executor.model_loader.default_loader import DefaultModelLoader
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.bert import (TOKEN_TYPE_SHIFT,
                                             BertEmbeddingModel, BertModel,
                                             _decode_token_type_ids,
                                             _encode_token_type_ids)
from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper,
                                              maybe_prefix)
from vllm.sequence import IntermediateTensors
+from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata

from .bert_with_rope import BertWithRope, JinaRobertaModel
from .interfaces import SupportsCrossEncoding, default_pooling_type
@@ -150,6 +157,130 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): |
        return loader.load_weights(weights_list, mapper=mapper)


+class M3SparsePooler(Pooler):
+    """A pooler that implements BGE-M3 sparse pooling.
+
+    This pooler handles the "embed-sparse" pooling task: each token's
+    hidden state is projected to a single non-negative weight, and the
+    weights of the special tokens inserted by the tokenizer are dropped.
+    Dense embeddings remain the default and are produced by the other
+    poolers registered in the dispatching pooler.
+
+    Attributes:
+        sparse_linear: the linear module applied to the hidden states
+            to obtain the per-token weights.
+        bos_token_id and eos_token_id: the special tokens inserted by
+            the tokenizer; their weights are removed from the sparse
+            embeddings.
+    """
+
+    def __init__(self, sparse_linear: nn.Module, bos_token_id: int,
+                 eos_token_id: int) -> None:
+        super().__init__()
+        self.sparse_linear = sparse_linear
+        self.bos_token_id = bos_token_id
+        self.eos_token_id = eos_token_id
+
+    def get_supported_tasks(self) -> set[PoolingTask]:
+        return {"embed-sparse"}
+
+    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
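+        # The prompt token ids must be available at pooling time so that
+        # the BOS/EOS positions can be located and stripped.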
+        return PoolingParamsUpdate(requires_token_ids=True)
+
+    def forward(
+        self,
+        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
+        pooling_metadata: PoolingMetadata,
+    ) -> PoolerOutput:
+
+        assert isinstance(pooling_metadata, V1PoolingMetadata), \
+            "BGE-M3 sparse embeddings are only supported with V1"
+        assert isinstance(hidden_states, list)
+
+        pooled_outputs = []
+
+        for i, hidden_state in enumerate(hidden_states):
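+            # Each token's sparse weight is the ReLU of a scalar projection
+            # of its hidden state (the BGE-M3 sparse/lexical head).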
+            pooled_data = torch.squeeze(torch.relu(
+                self.sparse_linear(hidden_state)),
+                                        dim=0)
+            token_ids = pooling_metadata.prompt_token_ids[
+                i, :pooling_metadata.prompt_lens[i]]
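+            # Drop the weights of the special tokens added by the tokenizer.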
+            if token_ids[0] == self.bos_token_id:
+                pooled_data = pooled_data[1:]
+            if token_ids[-1] == self.eos_token_id:
+                pooled_data = pooled_data[:-1]
+            pooled_outputs.append(pooled_data)
+
+        return PoolerOutput(outputs=build_output(pooled_outputs))
+
+
+def filter_secondary_weights(
+    all_weights: Iterable[tuple[str, torch.Tensor]],
+    secondary_weights: list[str],
+) -> tuple[Iterable[tuple[str, torch.Tensor]], Iterable[tuple[str,
+                                                              torch.Tensor]]]:
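+    """Split a stream of weights into (secondary, primary) iterators.
+
+    itertools.tee duplicates the underlying iterator so that the two
+    filtered generators can be consumed independently.
+    """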
+    all_weights1, all_weights2 = itertools.tee(all_weights)
+
+    def filtered(n):
+        return any(n.startswith(f) for f in secondary_weights)
+
+    return ((n, w) for n, w in all_weights1 if filtered(n)), \
+           ((n, w) for n, w in all_weights2 if not filtered(n))
+
+
+class BgeM3EmbeddingModel(RobertaEmbeddingModel):
+    """A model that extends RobertaEmbeddingModel with sparse embeddings.
+
+    This class supports loading an additional sparse_linear.pt file
+    to create sparse embeddings as described in https://arxiv.org/abs/2402.03216
+    """
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+
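+        # Set these before super().__init__() so they are already available
+        # when the parent constructor calls _build_pooler().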
+        self.hidden_size = vllm_config.model_config.hf_config.hidden_size
+
+        self.bos_token_id = vllm_config.model_config.hf_config.bos_token_id
+        self.eos_token_id = vllm_config.model_config.hf_config.eos_token_id
+
+        super().__init__(vllm_config=vllm_config, prefix=prefix)
+        self.secondary_weight_prefix = "sparse_linear."
+
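+        # sparse_linear.pt is shipped as a separate file next to the model
+        # weights; register it as an extra source for the default loader.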
+        self.secondary_weights = [
+            DefaultModelLoader.Source(
+                model_or_path=vllm_config.model_config.model,
+                revision=None,
+                prefix=self.secondary_weight_prefix,
+                allow_patterns_overrides=["sparse_linear.pt"])
+        ]
+
+    def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
+        self.sparse_linear = nn.Linear(self.hidden_size, 1)
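+        # Dispatch each pooling task to its own pooler; "embed-sparse" is
+        # served by the M3 sparse head, the rest by the standard poolers.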
+        return DispatchPooler({
+            "encode":
+            Pooler.for_encode(pooler_config),
+            "embed":
+            Pooler.for_embed(pooler_config),
+            "embed-sparse":
+            M3SparsePooler(self.sparse_linear, self.bos_token_id,
+                           self.eos_token_id),
+        })
+
+    def load_weights(self, all_weights: Iterable[tuple[str, torch.Tensor]]):
+        secondary, weights = filter_secondary_weights(
+            all_weights, [self.secondary_weight_prefix])
+
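+        # The backbone weights go through the regular RobertaEmbeddingModel
+        # loading path; the sparse_linear weights are loaded directly below.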
+        super().load_weights(weights)
+
+        params_dict = dict(self.named_parameters())
+
+        for name, loaded_weight in secondary:
+            if name.startswith(self.secondary_weight_prefix):
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+
@default_pooling_type("CLS")
class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
    """A model that uses Roberta to provide embedding functionalities.
|