Skip to content

Commit

Permalink
[refactor] model loading - no more unnecessary file downloads (#2345)
Browse files Browse the repository at this point in the history
* Refactor model loading: no full repo download

* Add simple test regarding efficient loading

* Replace use_auth_token with token in docstring

Deprecated arguments are not listed in docstrings

* Prevent crash if internet is down

* Use load_file_path in "is_sbert_model"
  • Loading branch information
tomaarsen committed Dec 12, 2023
1 parent 33915ab commit 331549c
Show file tree
Hide file tree
Showing 3 changed files with 152 additions and 115 deletions.
78 changes: 46 additions & 32 deletions sentence_transformers/SentenceTransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import shutil
import stat
from collections import OrderedDict
import warnings
from typing import List, Dict, Tuple, Iterable, Type, Union, Callable, Optional, Literal
import numpy as np
from numpy import ndarray
Expand All @@ -22,7 +23,7 @@

from . import __MODEL_HUB_ORGANIZATION__
from .evaluation import SentenceEvaluator
from .util import import_from_string, batch_to_device, fullname, snapshot_download
from .util import import_from_string, batch_to_device, fullname, is_sentence_transformer_model, load_dir_path, load_file_path
from .models import Transformer, Pooling
from .model_card_templates import ModelCardTemplate
from . import __version__
Expand Down Expand Up @@ -59,17 +60,27 @@ class SentenceTransformer(nn.Sequential):
:param device: Device (like "cuda", "cpu", "mps") that should be used for computation. If None, checks if a GPU
can be used.
:param cache_folder: Path to store models. Can be also set by SENTENCE_TRANSFORMERS_HOME environment variable.
:param use_auth_token: Hugging Face authentication token to download private models.
:param token: Hugging Face authentication token to download private models.
"""
def __init__(self, model_name_or_path: Optional[str] = None,
modules: Optional[Iterable[nn.Module]] = None,
device: Optional[str] = None,
cache_folder: Optional[str] = None,
use_auth_token: Union[bool, str, None] = None
token: Optional[Union[bool, str]] = None,
use_auth_token: Optional[Union[bool, str]] = None,
):
self._model_card_vars = {}
self._model_card_text = None
self._model_config = {}
if use_auth_token is not None:
warnings.warn(
"The `use_auth_token` argument is deprecated and will be removed in v3 of SentenceTransformers.", FutureWarning
)
if token is not None:
raise ValueError(
"`token` and `use_auth_token` are both specified. Please set only the argument `token`."
)
token = use_auth_token

if cache_folder is None:
cache_folder = os.getenv('SENTENCE_TRANSFORMERS_HOME')
Expand All @@ -86,13 +97,10 @@ def __init__(self, model_name_or_path: Optional[str] = None,
if model_name_or_path is not None and model_name_or_path != "":
logger.info("Load pretrained SentenceTransformer: {}".format(model_name_or_path))

#Old models that don't belong to any organization
# Old models that don't belong to any organization
basic_transformer_models = ['albert-base-v1', 'albert-base-v2', 'albert-large-v1', 'albert-large-v2', 'albert-xlarge-v1', 'albert-xlarge-v2', 'albert-xxlarge-v1', 'albert-xxlarge-v2', 'bert-base-cased-finetuned-mrpc', 'bert-base-cased', 'bert-base-chinese', 'bert-base-german-cased', 'bert-base-german-dbmdz-cased', 'bert-base-german-dbmdz-uncased', 'bert-base-multilingual-cased', 'bert-base-multilingual-uncased', 'bert-base-uncased', 'bert-large-cased-whole-word-masking-finetuned-squad', 'bert-large-cased-whole-word-masking', 'bert-large-cased', 'bert-large-uncased-whole-word-masking-finetuned-squad', 'bert-large-uncased-whole-word-masking', 'bert-large-uncased', 'camembert-base', 'ctrl', 'distilbert-base-cased-distilled-squad', 'distilbert-base-cased', 'distilbert-base-german-cased', 'distilbert-base-multilingual-cased', 'distilbert-base-uncased-distilled-squad', 'distilbert-base-uncased-finetuned-sst-2-english', 'distilbert-base-uncased', 'distilgpt2', 'distilroberta-base', 'gpt2-large', 'gpt2-medium', 'gpt2-xl', 'gpt2', 'openai-gpt', 'roberta-base-openai-detector', 'roberta-base', 'roberta-large-mnli', 'roberta-large-openai-detector', 'roberta-large', 't5-11b', 't5-3b', 't5-base', 't5-large', 't5-small', 'transfo-xl-wt103', 'xlm-clm-ende-1024', 'xlm-clm-enfr-1024', 'xlm-mlm-100-1280', 'xlm-mlm-17-1280', 'xlm-mlm-en-2048', 'xlm-mlm-ende-1024', 'xlm-mlm-enfr-1024', 'xlm-mlm-enro-1024', 'xlm-mlm-tlm-xnli15-1024', 'xlm-mlm-xnli15-1024', 'xlm-roberta-base', 'xlm-roberta-large-finetuned-conll02-dutch', 'xlm-roberta-large-finetuned-conll02-spanish', 'xlm-roberta-large-finetuned-conll03-english', 'xlm-roberta-large-finetuned-conll03-german', 'xlm-roberta-large', 'xlnet-base-cased', 'xlnet-large-cased']

if os.path.exists(model_name_or_path):
#Load from path
model_path = model_name_or_path
else:
if not os.path.exists(model_name_or_path):
#Not a path, load from hub
if '\\' in model_name_or_path or model_name_or_path.count('/') > 1:
raise ValueError("Path {} not found".format(model_name_or_path))
Expand All @@ -101,21 +109,10 @@ def __init__(self, model_name_or_path: Optional[str] = None,
# A model from sentence-transformers
model_name_or_path = __MODEL_HUB_ORGANIZATION__ + "/" + model_name_or_path

model_path = os.path.join(cache_folder, model_name_or_path.replace("/", "_"))

if not os.path.exists(os.path.join(model_path, 'modules.json')):
# Download from hub with caching
snapshot_download(model_name_or_path,
cache_dir=cache_folder,
library_name='sentence-transformers',
library_version=__version__,
ignore_files=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5'],
use_auth_token=use_auth_token)

if os.path.exists(os.path.join(model_path, 'modules.json')): #Load as SentenceTransformer model
modules = self._load_sbert_model(model_path)
else: #Load with AutoModel
modules = self._load_auto_model(model_path)
if is_sentence_transformer_model(model_name_or_path, token, cache_folder=cache_folder):
modules = self._load_sbert_model(model_name_or_path, token=token, cache_folder=cache_folder)
else:
modules = self._load_auto_model(model_name_or_path, token=token, cache_folder=cache_folder)

if modules is not None and not isinstance(modules, OrderedDict):
modules = OrderedDict([(str(idx), module) for idx, module in enumerate(modules)])
Expand Down Expand Up @@ -823,46 +820,63 @@ def _save_checkpoint(self, checkpoint_path, checkpoint_save_total_limit, step):
shutil.rmtree(old_checkpoints[0]['path'])


def _load_auto_model(self, model_name_or_path):
def _load_auto_model(self, model_name_or_path: str, token: Optional[Union[bool, str]], cache_folder: Optional[str]):
"""
Creates a simple Transformer + Mean Pooling model and returns the modules
"""
logger.warning("No sentence-transformers model found with name {}. Creating a new one with MEAN pooling.".format(model_name_or_path))
transformer_model = Transformer(model_name_or_path)
transformer_model = Transformer(model_name_or_path, cache_dir=cache_folder, model_args={"token": token})
pooling_model = Pooling(transformer_model.get_word_embedding_dimension(), 'mean')
return [transformer_model, pooling_model]

def _load_sbert_model(self, model_path):
def _load_sbert_model(self, model_name_or_path: str, token: Optional[Union[bool, str]], cache_folder: Optional[str]):
"""
Loads a full sentence-transformers model
"""
# Check if the config_sentence_transformers.json file exists (exists since v2 of the framework)
config_sentence_transformers_json_path = os.path.join(model_path, 'config_sentence_transformers.json')
if os.path.exists(config_sentence_transformers_json_path):
config_sentence_transformers_json_path = load_file_path(model_name_or_path, 'config_sentence_transformers.json', token=token, cache_folder=cache_folder)
if config_sentence_transformers_json_path is not None:
with open(config_sentence_transformers_json_path) as fIn:
self._model_config = json.load(fIn)

if '__version__' in self._model_config and 'sentence_transformers' in self._model_config['__version__'] and self._model_config['__version__']['sentence_transformers'] > __version__:
logger.warning("You try to use a model that was created with version {}, however, your version is {}. This might cause unexpected behavior or errors. In that case, try to update to the latest version.\n\n\n".format(self._model_config['__version__']['sentence_transformers'], __version__))

# Check if a readme exists
model_card_path = os.path.join(model_path, 'README.md')
if os.path.exists(model_card_path):
model_card_path = load_file_path(model_name_or_path, 'README.md', token=token, cache_folder=cache_folder)
if model_card_path is not None:
try:
with open(model_card_path, encoding='utf8') as fIn:
self._model_card_text = fIn.read()
except:
pass

# Load the modules of sentence transformer
modules_json_path = os.path.join(model_path, 'modules.json')
modules_json_path = load_file_path(model_name_or_path, 'modules.json', token=token, cache_folder=cache_folder)
with open(modules_json_path) as fIn:
modules_config = json.load(fIn)

modules = OrderedDict()
for module_config in modules_config:
module_class = import_from_string(module_config['type'])
module = module_class.load(os.path.join(model_path, module_config['path']))
# For Transformer, don't load the full directory, rely on `transformers` instead
# But, do load the config file first.
if module_class == Transformer and module_config['path'] == "":
kwargs = {}
for config_name in ['sentence_bert_config.json', 'sentence_roberta_config.json', 'sentence_distilbert_config.json', 'sentence_camembert_config.json', 'sentence_albert_config.json', 'sentence_xlm-roberta_config.json', 'sentence_xlnet_config.json']:
config_path = load_file_path(model_name_or_path, config_name, token=token, cache_folder=cache_folder)
if config_path is not None:
with open(config_path) as fIn:
kwargs = json.load(fIn)
break
if "model_args" in kwargs:
kwargs["model_args"]["token"] = token
else:
kwargs["model_args"] = {"token": token}
module = Transformer(model_name_or_path, cache_dir=cache_folder, **kwargs)
else:
module_path = load_dir_path(model_name_or_path, module_config['path'], token=token, cache_folder=cache_folder)
module = module_class.load(module_path)
modules[module_config['name']] = module

return modules
Expand Down
142 changes: 59 additions & 83 deletions sentence_transformers/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
from typing import Dict, Optional, Union
from pathlib import Path

import huggingface_hub
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from huggingface_hub import HfApi, hf_hub_url, cached_download, HfFolder
from huggingface_hub import snapshot_download, hf_hub_download
from huggingface_hub.utils import EntryNotFoundError
import fnmatch
from packaging import version
import heapq
Expand Down Expand Up @@ -424,86 +424,62 @@ def community_detection(embeddings, threshold=0.75, min_community_size=10, batch
######################



def snapshot_download(
repo_id: str,
revision: Optional[str] = None,
cache_dir: Union[str, Path, None] = None,
library_name: Optional[str] = None,
library_version: Optional[str] = None,
user_agent: Union[Dict, str, None] = None,
ignore_files: Optional[List[str]] = None,
use_auth_token: Union[bool, str, None] = None
) -> str:
class disabled_tqdm(tqdm):
"""
Method derived from huggingface_hub.
Adds a new parameters 'ignore_files', which allows to ignore certain files / file-patterns
Class to override `disable` argument in case progress bars are globally disabled.
Taken from https://github.com/tqdm/tqdm/issues/619#issuecomment-619639324.
"""
if cache_dir is None:
cache_dir = HUGGINGFACE_HUB_CACHE
if isinstance(cache_dir, Path):
cache_dir = str(cache_dir)

_api = HfApi()

token = None
if isinstance(use_auth_token, str):
token = use_auth_token
elif use_auth_token:
token = HfFolder.get_token()

model_info = _api.model_info(repo_id=repo_id, revision=revision, token=token)

storage_folder = os.path.join(
cache_dir, repo_id.replace("/", "_")
)

all_files = model_info.siblings
#Download modules.json as the last file
for idx, repofile in enumerate(all_files):
if repofile.rfilename == "modules.json":
del all_files[idx]
all_files.append(repofile)
break

for model_file in all_files:
if ignore_files is not None:
skip_download = False
for pattern in ignore_files:
if fnmatch.fnmatch(model_file.rfilename, pattern):
skip_download = True
break

if skip_download:
continue

url = hf_hub_url(
repo_id, filename=model_file.rfilename, revision=model_info.sha
)
relative_filepath = os.path.join(*model_file.rfilename.split("/"))

# Create potential nested dir
nested_dirname = os.path.dirname(
os.path.join(storage_folder, relative_filepath)
)
os.makedirs(nested_dirname, exist_ok=True)

cached_download_args = {'url': url,
'cache_dir': storage_folder,
'force_filename': relative_filepath,
'library_name': library_name,
'library_version': library_version,
'user_agent': user_agent,
'use_auth_token': use_auth_token}

if version.parse(huggingface_hub.__version__) >= version.parse("0.8.1"):
# huggingface_hub v0.8.1 introduces a new cache layout. We sill use a manual layout
# And need to pass legacy_cache_layout=True to avoid that a warning will be printed
cached_download_args['legacy_cache_layout'] = True

path = cached_download(**cached_download_args)

if os.path.exists(path + ".lock"):
os.remove(path + ".lock")

return storage_folder

def __init__(self, *args, **kwargs):
kwargs["disable"] = True
super().__init__(*args, **kwargs)

def __delattr__(self, attr: str) -> None:
"""Fix for https://github.com/huggingface/huggingface_hub/issues/1603"""
try:
super().__delattr__(attr)
except AttributeError:
if attr != "_lock":
raise


def is_sentence_transformer_model(model_name_or_path: str, token: Optional[Union[bool, str]] = None, cache_folder: Optional[str] = None) -> bool:
return bool(load_file_path(model_name_or_path, "modules.json", token, cache_folder))


def load_file_path(model_name_or_path: str, filename: str, token: Optional[Union[bool, str]], cache_folder: Optional[str]) -> Optional[str]:
# If file is local
file_path = os.path.join(model_name_or_path, filename)
if os.path.exists(file_path):
return file_path

# If file is remote
try:
return hf_hub_download(model_name_or_path, filename=filename, library_name="sentence-transformers", token=token, cache_dir=cache_folder)
except Exception:
return


def load_dir_path(model_name_or_path: str, directory: str, token: Optional[Union[bool, str]], cache_folder: Optional[str]) -> Optional[str]:
# If file is local
dir_path = os.path.join(model_name_or_path, directory)
if os.path.exists(dir_path):
return dir_path

download_kwargs = {
"repo_id": model_name_or_path,
"allow_patterns":f"{directory}/**",
"library_name": "sentence-transformers",
"token": token,
"cache_dir": cache_folder,
"tqdm_class": disabled_tqdm,
}
# Try to download from the remote
try:
repo_path = snapshot_download(**download_kwargs)
except Exception:
# Otherwise, try local (i.e. cache) only
download_kwargs["local_files_only"] = True
repo_path = snapshot_download(**download_kwargs)
return os.path.join(repo_path, directory)
47 changes: 47 additions & 0 deletions tests/test_sentence_transformer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""
Tests general behaviour of the SentenceTransformer class
"""

from pathlib import Path
import tempfile

import torch
from sentence_transformers import SentenceTransformer
from sentence_transformers.models import Transformer, Pooling
import unittest


class TestSentenceTransformer(unittest.TestCase):
def test_load_with_safetensors(self):
with tempfile.TemporaryDirectory() as cache_folder:
safetensors_model = SentenceTransformer(
"sentence-transformers-testing/stsb-bert-tiny-safetensors",
cache_folder=cache_folder,
)

# Only the safetensors file must be loaded
pytorch_files = list(Path(cache_folder).glob("**/pytorch_model.bin"))
self.assertEqual(0, len(pytorch_files), msg="PyTorch model file must not be downloaded.")
safetensors_files = list(Path(cache_folder).glob("**/model.safetensors"))
self.assertEqual(1, len(safetensors_files), msg="Safetensors model file must be downloaded.")

with tempfile.TemporaryDirectory() as cache_folder:
transformer = Transformer(
"sentence-transformers-testing/stsb-bert-tiny-safetensors",
cache_dir=cache_folder,
model_args={"use_safetensors": False},
)
pooling = Pooling(transformer.get_word_embedding_dimension())
pytorch_model = SentenceTransformer(modules=[transformer, pooling])

# Only the pytorch file must be loaded
pytorch_files = list(Path(cache_folder).glob("**/pytorch_model.bin"))
self.assertEqual(1, len(pytorch_files), msg="PyTorch model file must be downloaded.")
safetensors_files = list(Path(cache_folder).glob("**/model.safetensors"))
self.assertEqual(0, len(safetensors_files), msg="Safetensors model file must not be downloaded.")

sentences = ["This is a test sentence", "This is another test sentence"]
self.assertTrue(
torch.equal(safetensors_model.encode(sentences, convert_to_tensor=True), pytorch_model.encode(sentences, convert_to_tensor=True)),
msg="Ensure that Safetensors and PyTorch loaded models result in identical embeddings",
)

0 comments on commit 331549c

Please sign in to comment.