From 78553270abc74f44c1504db0e29f79591af6b697 Mon Sep 17 00:00:00 2001
From: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com>
Date: Tue, 8 Oct 2024 16:32:22 +0200
Subject: [PATCH] [`feat`] Add lightning-fast StaticEmbedding module based on
 model2vec (#2961)

* Add lightning-fast StaticEmbedding module based on model2vec

* Add explicit kwargs for StaticEmbedding.from_distillation
---
 .../sentence_transformer/models.md           |   1 +
 .../losses/MatryoshkaLoss.py                 |   3 +-
 sentence_transformers/model_card.py          |   5 +-
 .../models/StaticEmbedding.py                | 206 ++++++++++++++++++
 sentence_transformers/models/__init__.py     |   2 +
 5 files changed, 215 insertions(+), 2 deletions(-)
 create mode 100644 sentence_transformers/models/StaticEmbedding.py

diff --git a/docs/package_reference/sentence_transformer/models.md b/docs/package_reference/sentence_transformer/models.md
index f54a7d11f..84796f916 100644
--- a/docs/package_reference/sentence_transformer/models.md
+++ b/docs/package_reference/sentence_transformer/models.md
@@ -15,6 +15,7 @@
 .. autoclass:: sentence_transformers.models.CNN
 .. autoclass:: sentence_transformers.models.LSTM
 .. autoclass:: sentence_transformers.models.Normalize
+.. autoclass:: sentence_transformers.models.StaticEmbedding
 .. autoclass:: sentence_transformers.models.WeightedLayerPooling
 .. autoclass:: sentence_transformers.models.WordEmbeddings
 .. autoclass:: sentence_transformers.models.WordWeights
diff --git a/sentence_transformers/losses/MatryoshkaLoss.py b/sentence_transformers/losses/MatryoshkaLoss.py
index e4a6dd851..e6a18aac0 100644
--- a/sentence_transformers/losses/MatryoshkaLoss.py
+++ b/sentence_transformers/losses/MatryoshkaLoss.py
@@ -44,7 +44,8 @@ def __call__(self, features: dict[str, Tensor]) -> dict[str, Tensor]:
         # Using cache:
         else:
             output = self.cache[self.idx]
-        output["token_embeddings"] = self.shrink(output["token_embeddings"])
+        if "token_embeddings" in output:
+            output["token_embeddings"] = self.shrink(output["token_embeddings"])
         output["sentence_embedding"] = self.shrink(output["sentence_embedding"])
         self.idx += 1
         return output
diff --git a/sentence_transformers/model_card.py b/sentence_transformers/model_card.py
index 92eed253f..99da35d96 100644
--- a/sentence_transformers/model_card.py
+++ b/sentence_transformers/model_card.py
@@ -27,7 +27,7 @@
 from transformers.trainer_callback import TrainerControl, TrainerState
 
 from sentence_transformers import __version__ as sentence_transformers_version
-from sentence_transformers.models import Transformer
+from sentence_transformers.models import StaticEmbedding, Transformer
 from sentence_transformers.training_args import SentenceTransformerTrainingArguments
 from sentence_transformers.util import fullname, is_accelerate_available, is_datasets_available
 
@@ -753,6 +753,9 @@ def try_to_set_base_model(self) -> None:
             for model_id in candidate_model_ids:
                 if self.set_base_model(model_id):
                     break
+        elif isinstance(self.model[0], StaticEmbedding):
+            if self.model[0].base_model:
+                self.set_base_model(self.model[0].base_model)
 
     def format_eval_metrics(self) -> dict[str, Any]:
         """Format the evaluation metrics for the model card.
diff --git a/sentence_transformers/models/StaticEmbedding.py b/sentence_transformers/models/StaticEmbedding.py
new file mode 100644
index 000000000..de69285b2
--- /dev/null
+++ b/sentence_transformers/models/StaticEmbedding.py
@@ -0,0 +1,206 @@
+from __future__ import annotations
+
+import math
+import os
+from pathlib import Path
+
+import numpy as np
+import torch
+from safetensors.torch import load_file as load_safetensors_file
+from safetensors.torch import save_file as save_safetensors_file
+from tokenizers import Tokenizer
+from torch import nn
+from transformers import PreTrainedTokenizerFast
+
+from sentence_transformers.util import get_device_name
+
+
+class StaticEmbedding(nn.Module):
+    def __init__(
+        self,
+        tokenizer: Tokenizer | PreTrainedTokenizerFast,
+        embedding_weights: np.array | torch.Tensor | None = None,
+        embedding_dim: int | None = None,
+        **kwargs,
+    ) -> None:
+        """
+        Initializes the StaticEmbedding model given a tokenizer. The model is a simple embedding bag model that
+        takes the mean of trained per-token embeddings to compute text embeddings.
+
+        Args:
+            tokenizer (Tokenizer | PreTrainedTokenizerFast): The tokenizer to be used. Must be a fast tokenizer
+                from ``transformers`` or ``tokenizers``.
+            embedding_weights (np.array | torch.Tensor | None, optional): Pre-trained embedding weights.
+                Defaults to None.
+            embedding_dim (int | None, optional): Dimension of the embeddings. Required if embedding_weights
+                is not provided. Defaults to None.
+
+        Example::
+
+            from sentence_transformers import SentenceTransformer
+            from sentence_transformers.models import StaticEmbedding
+            from tokenizers import Tokenizer
+
+            # Pre-distilled embeddings:
+            static_embedding = StaticEmbedding.from_model2vec("minishlab/M2V_base_output")
+            # or distill your own embeddings:
+            static_embedding = StaticEmbedding.from_distillation("BAAI/bge-base-en-v1.5", device="cuda")
+            # or start with randomized embeddings:
+            tokenizer = Tokenizer.from_pretrained("FacebookAI/xlm-roberta-base")
+            static_embedding = StaticEmbedding(tokenizer, embedding_dim=512)
+
+            model = SentenceTransformer(modules=[static_embedding])
+
+            embeddings = model.encode(["What are Pandas?", "The giant panda (Ailuropoda melanoleuca; Chinese: 大熊猫; pinyin: dàxióngmāo), also known as the panda bear or simply the panda, is a bear native to south central China."])
+            similarity = model.similarity(embeddings[0], embeddings[1])
+            # tensor([[0.9177]]) (If you use the distilled bge-base)
+
+        Raises:
+            ValueError: If the tokenizer is not a fast tokenizer.
+            ValueError: If neither `embedding_weights` nor `embedding_dim` is provided.
+        """
+        super().__init__()
+
+        if isinstance(tokenizer, PreTrainedTokenizerFast):
+            tokenizer = tokenizer._tokenizer
+        elif not isinstance(tokenizer, Tokenizer):
+            raise ValueError(
+                "The tokenizer must be fast (i.e. Rust-backed) to use this class. "
+                "Use Tokenizer.from_pretrained() from `tokenizers` to load a fast tokenizer."
+            )
+
+        if embedding_weights is not None:
+            if isinstance(embedding_weights, np.ndarray):
+                embedding_weights = torch.from_numpy(embedding_weights)
+
+            self.embedding = nn.EmbeddingBag.from_pretrained(embedding_weights, freeze=False)
+        elif embedding_dim is not None:
+            self.embedding = nn.EmbeddingBag(tokenizer.get_vocab_size(), embedding_dim)
+        else:
+            raise ValueError("Either `embedding_weights` or `embedding_dim` must be provided.")
+
+        self.num_embeddings = self.embedding.num_embeddings
+        self.embedding_dim = self.embedding.embedding_dim
+
+        self.tokenizer: Tokenizer = tokenizer
+        self.tokenizer.no_padding()
+
+        # For the model card
+        self.base_model = kwargs.get("base_model", None)
+
+    def tokenize(self, texts: list[str], **kwargs) -> dict[str, torch.Tensor]:
+        encodings = self.tokenizer.encode_batch(texts, add_special_tokens=False)
+        encodings_ids = [encoding.ids for encoding in encodings]
+
+        offsets = torch.from_numpy(np.cumsum([0] + [len(token_ids) for token_ids in encodings_ids[:-1]]))
+        input_ids = torch.tensor([token_id for token_ids in encodings_ids for token_id in token_ids], dtype=torch.long)
+        return {"input_ids": input_ids, "offsets": offsets}
+
+    def forward(self, features: dict[str, torch.Tensor], **kwargs) -> dict[str, torch.Tensor]:
+        features["sentence_embedding"] = self.embedding(features["input_ids"], features["offsets"])
+        return features
+
+    def get_config_dict(self) -> dict[str, float]:
+        return {}
+
+    @property
+    def max_seq_length(self) -> int:
+        return math.inf
+
+    def get_sentence_embedding_dimension(self) -> int:
+        return self.embedding_dim
+
+    def save(self, save_dir: str, safe_serialization: bool = True, **kwargs) -> None:
+        if safe_serialization:
+            save_safetensors_file(self.state_dict(), os.path.join(save_dir, "model.safetensors"))
+        else:
+            torch.save(self.state_dict(), os.path.join(save_dir, "pytorch_model.bin"))
+        self.tokenizer.save(str(Path(save_dir) / "tokenizer.json"))
+
+    def load(load_dir: str, **kwargs) -> StaticEmbedding:
+        tokenizer = Tokenizer.from_file(str(Path(load_dir) / "tokenizer.json"))
+        if os.path.exists(os.path.join(load_dir, "model.safetensors")):
+            weights = load_safetensors_file(os.path.join(load_dir, "model.safetensors"))
+        else:
+            weights = torch.load(
+                os.path.join(load_dir, "pytorch_model.bin"), map_location=torch.device("cpu"), weights_only=True
+            )
+        weights = weights["embedding.weight"]
+        return StaticEmbedding(tokenizer, embedding_weights=weights)
+
+    @classmethod
+    def from_distillation(
+        cls,
+        model_name: str,
+        vocabulary: list[str] | None = None,
+        device: str | None = None,
+        pca_dims: int | None = 256,
+        apply_zipf: bool = True,
+        use_subword: bool = True,
+    ) -> StaticEmbedding:
+        """
+        Creates a StaticEmbedding instance from a distillation process using the `model2vec` package.
+
+        Args:
+            model_name (str): The name of the model to distill.
+            vocabulary (list[str] | None, optional): A list of vocabulary words to use. Defaults to None.
+            device (str): The device to run the distillation on (e.g., 'cpu', 'cuda'). If not specified,
+                the strongest device is automatically detected. Defaults to None.
+            pca_dims (int | None, optional): The number of dimensions for PCA reduction. Defaults to 256.
+            apply_zipf (bool): Whether to apply Zipf's law during distillation. Defaults to True.
+            use_subword (bool): Whether to use subword tokenization. Defaults to True.
+
+        Returns:
+            StaticEmbedding: An instance of StaticEmbedding initialized with the distilled model's
+            tokenizer and embedding weights.
+
+        Raises:
+            ImportError: If the `model2vec` package is not installed.
+        """
+
+        try:
+            from model2vec import distill
+        except ImportError:
+            raise ImportError("To use this method, please install the `model2vec` package: `pip install model2vec`")
+
+        device = device or get_device_name()
+        static_model = distill(
+            model_name,
+            vocabulary=vocabulary,
+            device=device,
+            pca_dims=pca_dims,
+            apply_zipf=apply_zipf,
+            use_subword=use_subword,
+        )
+        embedding_weights = static_model.embedding.weight
+        tokenizer: Tokenizer = static_model.tokenizer
+
+        return cls(tokenizer, embedding_weights=embedding_weights, base_model=model_name)
+
+    @classmethod
+    def from_model2vec(cls, model_id_or_path: str) -> StaticEmbedding:
+        """
+        Create a StaticEmbedding instance from a model2vec model. This method loads a pre-trained model2vec model
+        and extracts the embedding weights and tokenizer to create a StaticEmbedding instance.
+
+        Args:
+            model_id_or_path (str): The identifier or path to the pre-trained model2vec model.
+
+        Returns:
+            StaticEmbedding: An instance of StaticEmbedding initialized with the tokenizer and embedding weights
+            of the model2vec model.
+
+        Raises:
+            ImportError: If the `model2vec` package is not installed.
+        """
+
+        try:
+            from model2vec import StaticModel
+        except ImportError:
+            raise ImportError("To use this method, please install the `model2vec` package: `pip install model2vec`")
+
+        static_model = StaticModel.from_pretrained(model_id_or_path)
+        embedding_weights = static_model.embedding.weight
+        tokenizer: Tokenizer = static_model.tokenizer
+
+        return cls(tokenizer, embedding_weights=embedding_weights, base_model=model_id_or_path)
diff --git a/sentence_transformers/models/__init__.py b/sentence_transformers/models/__init__.py
index d9684310a..c0fcd8dbf 100644
--- a/sentence_transformers/models/__init__.py
+++ b/sentence_transformers/models/__init__.py
@@ -10,6 +10,7 @@
 from .LSTM import LSTM
 from .Normalize import Normalize
 from .Pooling import Pooling
+from .StaticEmbedding import StaticEmbedding
 from .Transformer import Transformer
 from .WeightedLayerPooling import WeightedLayerPooling
 from .WordEmbeddings import WordEmbeddings
@@ -17,6 +18,7 @@
 
 __all__ = [
     "Transformer",
+    "StaticEmbedding",
    "Asym",
    "BoW",
    "CNN",
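
Usage sketch: assuming this patch is applied and `model2vec` is installed, the new module plugs into a regular SentenceTransformer pipeline as below. It only uses the APIs defined above (StaticEmbedding.from_distillation and the module's tokenize/forward/save/load hooks) together with the existing encode, similarity, and save methods of SentenceTransformer; the model id is taken from the docstring example, and the save path is an illustrative placeholder::

    from sentence_transformers import SentenceTransformer
    from sentence_transformers.models import StaticEmbedding

    # Distill static token embeddings from a Sentence Transformer model
    # (model id reused from the docstring example above).
    static_embedding = StaticEmbedding.from_distillation("BAAI/bge-base-en-v1.5", device="cpu")
    model = SentenceTransformer(modules=[static_embedding])

    # Each embedding is the mean of the per-token vectors (nn.EmbeddingBag with
    # its default "mean" mode), so encoding is a lookup plus an average rather
    # than a transformer forward pass.
    embeddings = model.encode(["What are Pandas?", "Pandas are a species of bear."])
    print(model.similarity(embeddings[0], embeddings[1]))

    # Saving the pipeline writes tokenizer.json plus model.safetensors for this
    # module via StaticEmbedding.save; loading goes through StaticEmbedding.load.
    # "static-embedding-model" is a hypothetical local directory.
    model.save("static-embedding-model")
    reloaded = SentenceTransformer("static-embedding-model")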