[feat] Add lightning-fast StaticEmbedding module based on model2vec (#2961)

* Add lightning-fast StaticEmbedding module based on model2vec

* Add explicit kwargs for StaticEmbedding.from_distillation
tomaarsen authored Oct 8, 2024
1 parent 07ae865 commit 7855327
Showing 5 changed files with 215 additions and 2 deletions.
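For context, a minimal usage sketch of the new module, based on the docstring example in the new StaticEmbedding.py below (the model ID comes from that docstring; treat this as illustrative rather than canonical):

    from sentence_transformers import SentenceTransformer
    from sentence_transformers.models import StaticEmbedding

    # Load pre-distilled static embeddings from the Hub and wrap them as a
    # regular Sentence Transformers module:
    static_embedding = StaticEmbedding.from_model2vec("minishlab/M2V_base_output")
    model = SentenceTransformer(modules=[static_embedding])

    embeddings = model.encode(["What are Pandas?", "The giant panda is a bear native to south central China."])
    print(model.similarity(embeddings[0], embeddings[1]))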
1 change: 1 addition & 0 deletions docs/package_reference/sentence_transformer/models.md
@@ -15,6 +15,7 @@
.. autoclass:: sentence_transformers.models.CNN
.. autoclass:: sentence_transformers.models.LSTM
.. autoclass:: sentence_transformers.models.Normalize
+.. autoclass:: sentence_transformers.models.StaticEmbedding
.. autoclass:: sentence_transformers.models.WeightedLayerPooling
.. autoclass:: sentence_transformers.models.WordEmbeddings
.. autoclass:: sentence_transformers.models.WordWeights
3 changes: 2 additions & 1 deletion sentence_transformers/losses/MatryoshkaLoss.py
@@ -44,7 +44,8 @@ def __call__(self, features: dict[str, Tensor]) -> dict[str, Tensor]:
        # Using cache:
        else:
            output = self.cache[self.idx]
-       output["token_embeddings"] = self.shrink(output["token_embeddings"])
+       if "token_embeddings" in output:
+           output["token_embeddings"] = self.shrink(output["token_embeddings"])
        output["sentence_embedding"] = self.shrink(output["sentence_embedding"])
        self.idx += 1
        return output
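This guard is needed because the new StaticEmbedding module (added below) only produces a "sentence_embedding" in its forward pass, so cached module outputs are no longer guaranteed to contain "token_embeddings". A sketch of the two output shapes (values elided; the keys are taken from the modules' forward methods):

    # Output of a Transformer module: both keys are present.
    transformer_output = {"token_embeddings": ..., "sentence_embedding": ...}

    # Output of a StaticEmbedding module: no token embeddings to shrink.
    static_output = {"sentence_embedding": ...}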
5 changes: 4 additions & 1 deletion sentence_transformers/model_card.py
@@ -27,7 +27,7 @@
from transformers.trainer_callback import TrainerControl, TrainerState

from sentence_transformers import __version__ as sentence_transformers_version
-from sentence_transformers.models import Transformer
+from sentence_transformers.models import StaticEmbedding, Transformer
from sentence_transformers.training_args import SentenceTransformerTrainingArguments
from sentence_transformers.util import fullname, is_accelerate_available, is_datasets_available

@@ -753,6 +753,9 @@ def try_to_set_base_model(self) -> None:
            for model_id in candidate_model_ids:
                if self.set_base_model(model_id):
                    break
+        elif isinstance(self.model[0], StaticEmbedding):
+            if self.model[0].base_model:
+                self.set_base_model(self.model[0].base_model)

    def format_eval_metrics(self) -> dict[str, Any]:
        """Format the evaluation metrics for the model card.
206 changes: 206 additions & 0 deletions sentence_transformers/models/StaticEmbedding.py
@@ -0,0 +1,206 @@
from __future__ import annotations

import math
import os
from pathlib import Path

import numpy as np
import torch
from safetensors.torch import load_file as load_safetensors_file
from safetensors.torch import save_file as save_safetensors_file
from tokenizers import Tokenizer
from torch import nn
from transformers import PreTrainedTokenizerFast

from sentence_transformers.util import get_device_name


class StaticEmbedding(nn.Module):
def __init__(
self,
tokenizer: Tokenizer | PreTrainedTokenizerFast,
        embedding_weights: np.ndarray | torch.Tensor | None = None,
embedding_dim: int | None = None,
**kwargs,
) -> None:
"""
Initializes the StaticEmbedding model given a tokenizer. The model is a simple embedding bag model that
takes the mean of trained per-token embeddings to compute text embeddings.
Args:
tokenizer (Tokenizer | PreTrainedTokenizerFast): The tokenizer to be used. Must be a fast tokenizer
from ``transformers`` or ``tokenizers``.
embedding_weights (np.array | torch.Tensor | None, optional): Pre-trained embedding weights.
Defaults to None.
embedding_dim (int | None, optional): Dimension of the embeddings. Required if embedding_weights
is not provided. Defaults to None.
Example::
from sentence_transformers import SentenceTransformer
from sentence_transformers.models import StaticEmbedding
from tokenizers import Tokenizer
# Pre-distilled embeddings:
static_embedding = StaticEmbedding.from_model2vec("minishlab/M2V_base_output")
# or distill your own embeddings:
static_embedding = StaticEmbedding.from_distillation("BAAI/bge-base-en-v1.5", device="cuda")
# or start with randomized embeddings:
tokenizer = Tokenizer.from_pretrained("FacebookAI/xlm-roberta-base")
static_embedding = StaticEmbedding(tokenizer, embedding_dim=512)
model = SentenceTransformer(modules=[static_embedding])
embeddings = model.encode(["What are Pandas?", "The giant panda (Ailuropoda melanoleuca; Chinese: 大熊猫; pinyin: dàxióngmāo), also known as the panda bear or simply the panda, is a bear native to south central China."])
similarity = model.similarity(embeddings[0], embeddings[1])
# tensor([[0.9177]]) (If you use the distilled bge-base)
Raises:
ValueError: If the tokenizer is not a fast tokenizer.
ValueError: If neither `embedding_weights` nor `embedding_dim` is provided.
"""
super().__init__()

if isinstance(tokenizer, PreTrainedTokenizerFast):
tokenizer = tokenizer._tokenizer
elif not isinstance(tokenizer, Tokenizer):
raise ValueError(
"The tokenizer must be fast (i.e. Rust-backed) to use this class. "
"Use Tokenizer.from_pretrained() from `tokenizers` to load a fast tokenizer."
)

if embedding_weights is not None:
if isinstance(embedding_weights, np.ndarray):
embedding_weights = torch.from_numpy(embedding_weights)

self.embedding = nn.EmbeddingBag.from_pretrained(embedding_weights, freeze=False)
elif embedding_dim is not None:
self.embedding = nn.EmbeddingBag(tokenizer.get_vocab_size(), embedding_dim)
else:
raise ValueError("Either `embedding_weights` or `embedding_dim` must be provided.")

self.num_embeddings = self.embedding.num_embeddings
self.embedding_dim = self.embedding.embedding_dim

self.tokenizer: Tokenizer = tokenizer
self.tokenizer.no_padding()

# For the model card
self.base_model = kwargs.get("base_model", None)

def tokenize(self, texts: list[str], **kwargs) -> dict[str, torch.Tensor]:
encodings = self.tokenizer.encode_batch(texts, add_special_tokens=False)
encodings_ids = [encoding.ids for encoding in encodings]

offsets = torch.from_numpy(np.cumsum([0] + [len(token_ids) for token_ids in encodings_ids[:-1]]))
input_ids = torch.tensor([token_id for token_ids in encodings_ids for token_id in token_ids], dtype=torch.long)
return {"input_ids": input_ids, "offsets": offsets}

def forward(self, features: dict[str, torch.Tensor], **kwargs) -> dict[str, torch.Tensor]:
features["sentence_embedding"] = self.embedding(features["input_ids"], features["offsets"])
return features

def get_config_dict(self) -> dict[str, float]:
return {}

    @property
    def max_seq_length(self) -> float:
        # Static embeddings impose no sequence-length limit.
        return math.inf

def get_sentence_embedding_dimension(self) -> int:
return self.embedding_dim

def save(self, save_dir: str, safe_serialization: bool = True, **kwargs) -> None:
if safe_serialization:
save_safetensors_file(self.state_dict(), os.path.join(save_dir, "model.safetensors"))
else:
torch.save(self.state_dict(), os.path.join(save_dir, "pytorch_model.bin"))
self.tokenizer.save(str(Path(save_dir) / "tokenizer.json"))

    @staticmethod
    def load(load_dir: str, **kwargs) -> StaticEmbedding:
tokenizer = Tokenizer.from_file(str(Path(load_dir) / "tokenizer.json"))
if os.path.exists(os.path.join(load_dir, "model.safetensors")):
weights = load_safetensors_file(os.path.join(load_dir, "model.safetensors"))
else:
weights = torch.load(
os.path.join(load_dir, "pytorch_model.bin"), map_location=torch.device("cpu"), weights_only=True
)
weights = weights["embedding.weight"]
return StaticEmbedding(tokenizer, embedding_weights=weights)

@classmethod
def from_distillation(
cls,
model_name: str,
vocabulary: list[str] | None = None,
device: str | None = None,
pca_dims: int | None = 256,
apply_zipf: bool = True,
use_subword: bool = True,
) -> StaticEmbedding:
"""
Creates a StaticEmbedding instance from a distillation process using the `model2vec` package.
Args:
model_name (str): The name of the model to distill.
vocabulary (list[str] | None, optional): A list of vocabulary words to use. Defaults to None.
device (str): The device to run the distillation on (e.g., 'cpu', 'cuda'). If not specified,
the strongest device is automatically detected. Defaults to None.
pca_dims (int | None, optional): The number of dimensions for PCA reduction. Defaults to 256.
apply_zipf (bool): Whether to apply Zipf's law during distillation. Defaults to True.
use_subword (bool): Whether to use subword tokenization. Defaults to True.
Returns:
StaticEmbedding: An instance of StaticEmbedding initialized with the distilled model's
tokenizer and embedding weights.
Raises:
ImportError: If the `model2vec` package is not installed.
"""

try:
from model2vec import distill
except ImportError:
raise ImportError("To use this method, please install the `model2vec` package: `pip install model2vec`")

        # Only auto-detect the device when the caller did not specify one,
        # as documented above:
        if device is None:
            device = get_device_name()
static_model = distill(
model_name,
vocabulary=vocabulary,
device=device,
pca_dims=pca_dims,
apply_zipf=apply_zipf,
use_subword=use_subword,
)
embedding_weights = static_model.embedding.weight
tokenizer: Tokenizer = static_model.tokenizer

return cls(tokenizer, embedding_weights=embedding_weights, base_model=model_name)

@classmethod
def from_model2vec(cls, model_id_or_path: str) -> StaticEmbedding:
"""
Create a StaticEmbedding instance from a model2vec model. This method loads a pre-trained model2vec model
and extracts the embedding weights and tokenizer to create a StaticEmbedding instance.
Args:
model_id_or_path (str): The identifier or path to the pre-trained model2vec model.
Returns:
StaticEmbedding: An instance of StaticEmbedding initialized with the tokenizer and embedding weights
the model2vec model.
Raises:
ImportError: If the `model2vec` package is not installed.
"""

try:
from model2vec import StaticModel
except ImportError:
raise ImportError("To use this method, please install the `model2vec` package: `pip install model2vec`")

static_model = StaticModel.from_pretrained(model_id_or_path)
embedding_weights = static_model.embedding.weight
tokenizer: Tokenizer = static_model.tokenizer

return cls(tokenizer, embedding_weights=embedding_weights, base_model=model_id_or_path)
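The tokenize/forward pair above relies on torch's nn.EmbeddingBag contract: all token ids are flattened into a single tensor, and offsets marks where each text begins, so each text is pooled in one bagged lookup. A minimal sketch of that contract, with made-up token ids and a small randomly initialized bag:

    import torch
    from torch import nn

    # Two texts flattened into one id tensor; offsets mark where each text starts.
    input_ids = torch.tensor([5, 9, 2, 7, 7], dtype=torch.long)  # text A -> [5, 9, 2], text B -> [7, 7]
    offsets = torch.tensor([0, 3])

    # nn.EmbeddingBag defaults to mode="mean", matching the mean-of-token-embeddings
    # behavior described in the class docstring.
    bag = nn.EmbeddingBag(num_embeddings=10, embedding_dim=4)
    sentence_embeddings = bag(input_ids, offsets)  # shape (2, 4): one pooled vector per text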
2 changes: 2 additions & 0 deletions sentence_transformers/models/__init__.py
@@ -10,13 +10,15 @@
from .LSTM import LSTM
from .Normalize import Normalize
from .Pooling import Pooling
+from .StaticEmbedding import StaticEmbedding
from .Transformer import Transformer
from .WeightedLayerPooling import WeightedLayerPooling
from .WordEmbeddings import WordEmbeddings
from .WordWeights import WordWeights

__all__ = [
"Transformer",
"StaticEmbedding",
"Asym",
"BoW",
"CNN",
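Tying the new pieces together, a hedged end-to-end sketch: distill static embeddings (requires `pip install model2vec`; the model ID comes from the class docstring), then round-trip them through the save() and load() methods shown above. The directory name is hypothetical:

    import os

    from sentence_transformers import SentenceTransformer
    from sentence_transformers.models import StaticEmbedding

    # Distill static embeddings from a Sentence Transformer checkpoint:
    static_embedding = StaticEmbedding.from_distillation("BAAI/bge-base-en-v1.5", device="cuda", pca_dims=256)

    # Round trip: save() writes model.safetensors + tokenizer.json, load() rebuilds the module.
    save_dir = "static-embedding-checkpoint"  # hypothetical path
    os.makedirs(save_dir, exist_ok=True)
    static_embedding.save(save_dir)
    restored = StaticEmbedding.load(save_dir)

    model = SentenceTransformer(modules=[restored])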
