Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[fix] Simplify save_to_hub, remove git dependency, add 'token' argument #2376

Merged
merged 6 commits into from
Dec 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,7 @@ nr_*/
/docs/make.bat
/docs/Makefile
/examples/training/quora_duplicate_questions/quora-IR-dataset/
build
build

htmlcov
.coverage
5 changes: 2 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
transformers>=4.6.0,<5.0.0
tokenizers>=0.10.3
transformers>=4.32.0,<5.0.0
tqdm
torch>=1.6.0
numpy
scikit-learn
scipy
nltk
sentencepiece
huggingface-hub
huggingface-hub>=0.15.1
Pillow
116 changes: 44 additions & 72 deletions sentence_transformers/SentenceTransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import numpy as np
from numpy import ndarray
import transformers
from huggingface_hub import HfApi, HfFolder, Repository
from huggingface_hub import HfApi
import torch
from torch import nn, Tensor, device
from torch.optim import Optimizer
Expand Down Expand Up @@ -450,8 +450,9 @@ def _create_model_card(self, path: str, model_name: Optional[str] = None, train_
fOut.write(model_card.strip())

def save_to_hub(self,
repo_name: str,
repo_id: str,
tomaarsen marked this conversation as resolved.
Show resolved Hide resolved
organization: Optional[str] = None,
token: Optional[str] = None,
private: Optional[bool] = None,
commit_message: str = "Add new SentenceTransformer model.",
local_model_path: Optional[str] = None,
Expand All @@ -461,90 +462,61 @@ def save_to_hub(self,
"""
Uploads all elements of this Sentence Transformer to a new HuggingFace Hub repository.

:param repo_name: Repository name for your model in the Hub.
:param organization: Organization in which you want to push your model or tokenizer (you must be a member of this organization).
:param repo_id: Repository name for your model in the Hub, including the user or organization.
:param token: An authentication token (See https://huggingface.co/settings/token)
:param private: Set to true, for hosting a prive model
:param commit_message: Message to commit while pushing.
:param local_model_path: Path of the model locally. If set, this file path will be uploaded. Otherwise, the current model will be uploaded
:param exist_ok: If true, saving to an existing repository is OK. If false, saving only to a new repository is possible
:param replace_model_card: If true, replace an existing model card in the hub with the automatically created model card
:param train_datasets: Datasets used to train the model. If set, the datasets will be added to the model card in the Hub.
:return: The url of the commit of your model in the given repository.
"""
token = HfFolder.get_token()
if token is None:
raise ValueError("You must login to the Hugging Face hub on this computer by typing `transformers-cli login`.")
:param organization: Deprecated. Organization in which you want to push your model or tokenizer (you must be a member of this organization).

if '/' in repo_name:
splits = repo_name.split('/', maxsplit=1)
if organization is None or organization == splits[0]:
organization = splits[0]
repo_name = splits[1]
:return: The url of the commit of your model in the repository on the Hugging Face Hub.
"""
if organization:
if "/" not in repo_id:
logger.warning(
f"Providing an `organization` to `save_to_hub` is deprecated, please use `repo_id=\"{organization}/{repo_id}\"` instead."
)
repo_id = f"{organization}/{repo_id}"
elif repo_id.split("/")[0] != organization:
raise ValueError("Providing an `organization` to `save_to_hub` is deprecated, please only use `repo_id`.")
else:
raise ValueError("You passed and invalid repository name: {}.".format(repo_name))
logger.warning(
f"Providing an `organization` to `save_to_hub` is deprecated, please only use `repo_id=\"{repo_id}\"` instead."
)

endpoint = "https://huggingface.co"
repo_id = repo_name
if organization:
repo_id = f"{organization}/{repo_id}"
repo_url = HfApi(endpoint=endpoint).create_repo(
api = HfApi(token=token)
repo_url = api.create_repo(
repo_id=repo_id,
private=private,
repo_type=None,
exist_ok=exist_ok,
)
if local_model_path:
folder_url = api.upload_folder(
repo_id=repo_id,
token=token,
private=private,
repo_type=None,
exist_ok=exist_ok,
folder_path=local_model_path,
commit_message=commit_message
)
full_model_name = repo_url[len(endpoint)+1:].strip("/")

with tempfile.TemporaryDirectory() as tmp_dir:
# First create the repo (and clone its content if it's nonempty).
logger.info("Create repository and clone it if it exists")
repo = Repository(tmp_dir, clone_from=repo_url)

# If user provides local files, copy them.
if local_model_path:
copy_tree(local_model_path, tmp_dir)
else: # Else, save model directly into local repo.
else:
with tempfile.TemporaryDirectory() as tmp_dir:
create_model_card = replace_model_card or not os.path.exists(os.path.join(tmp_dir, 'README.md'))
self.save(tmp_dir, model_name=full_model_name, create_model_card=create_model_card, train_datasets=train_datasets)

#Find files larger 5M and track with git-lfs
large_files = []
for root, dirs, files in os.walk(tmp_dir):
for filename in files:
file_path = os.path.join(root, filename)
rel_path = os.path.relpath(file_path, tmp_dir)

if os.path.getsize(file_path) > (5 * 1024 * 1024):
large_files.append(rel_path)

if len(large_files) > 0:
logger.info("Track files with git lfs: {}".format(", ".join(large_files)))
repo.lfs_track(large_files)

logger.info("Push model to the hub. This might take a while")
push_return = repo.push_to_hub(commit_message=commit_message)

def on_rm_error(func, path, exc_info):
# path contains the path of the file that couldn't be removed
# let's just assume that it's read-only and unlink it.
try:
os.chmod(path, stat.S_IWRITE)
os.unlink(path)
except:
pass

# Remove .git folder. On Windows, the .git folder might be read-only and cannot be deleted
# Hence, try to set write permissions on error
try:
for f in os.listdir(tmp_dir):
shutil.rmtree(os.path.join(tmp_dir, f), onerror=on_rm_error)
except Exception as e:
logger.warning("Error when deleting temp folder: {}".format(str(e)))
pass
self.save(tmp_dir, model_name=repo_url.repo_id, create_model_card=create_model_card, train_datasets=train_datasets)
folder_url = api.upload_folder(
repo_id=repo_id,
folder_path=tmp_dir,
commit_message=commit_message
)

refs = api.list_repo_refs(repo_id=repo_id)
for branch in refs.branches:
if branch.name == "main":
return f"https://huggingface.co/{repo_id}/commit/{branch.target_commit}"
# This isn't expected to ever be reached.
return folder_url

return push_return

def smart_batching_collate(self, batch):
"""
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,15 @@
packages=find_packages(),
python_requires=">=3.8.0",
install_requires=[
'transformers>=4.6.0,<5.0.0',
'transformers>=4.32.0,<5.0.0',
'tqdm',
'torch>=1.6.0',
'numpy',
'scikit-learn',
'scipy',
'nltk',
'sentencepiece',
'huggingface-hub>=0.4.0',
'huggingface-hub>=0.15.1',
'Pillow'
],
classifiers=[
Expand Down
166 changes: 117 additions & 49 deletions tests/test_sentence_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,64 +3,132 @@
"""


import logging
from pathlib import Path
import tempfile
import pytest

from huggingface_hub import HfApi, RepoUrl, GitRefs, GitRefInfo
import torch
from sentence_transformers import SentenceTransformer
from sentence_transformers.models import Transformer, Pooling
import unittest


class TestSentenceTransformer(unittest.TestCase):
def test_load_with_safetensors(self):
with tempfile.TemporaryDirectory() as cache_folder:
safetensors_model = SentenceTransformer(
"sentence-transformers-testing/stsb-bert-tiny-safetensors",
cache_folder=cache_folder,
)

# Only the safetensors file must be loaded
pytorch_files = list(Path(cache_folder).glob("**/pytorch_model.bin"))
self.assertEqual(0, len(pytorch_files), msg="PyTorch model file must not be downloaded.")
safetensors_files = list(Path(cache_folder).glob("**/model.safetensors"))
self.assertEqual(1, len(safetensors_files), msg="Safetensors model file must be downloaded.")

with tempfile.TemporaryDirectory() as cache_folder:
transformer = Transformer(
"sentence-transformers-testing/stsb-bert-tiny-safetensors",
cache_dir=cache_folder,
model_args={"use_safetensors": False},
)
pooling = Pooling(transformer.get_word_embedding_dimension())
pytorch_model = SentenceTransformer(modules=[transformer, pooling])

# Only the pytorch file must be loaded
pytorch_files = list(Path(cache_folder).glob("**/pytorch_model.bin"))
self.assertEqual(1, len(pytorch_files), msg="PyTorch model file must be downloaded.")
safetensors_files = list(Path(cache_folder).glob("**/model.safetensors"))
self.assertEqual(0, len(safetensors_files), msg="Safetensors model file must not be downloaded.")

sentences = ["This is a test sentence", "This is another test sentence"]
self.assertTrue(
torch.equal(safetensors_model.encode(sentences, convert_to_tensor=True), pytorch_model.encode(sentences, convert_to_tensor=True)),
msg="Ensure that Safetensors and PyTorch loaded models result in identical embeddings",


def test_load_with_safetensors() -> None:
with tempfile.TemporaryDirectory() as cache_folder:
safetensors_model = SentenceTransformer(
"sentence-transformers-testing/stsb-bert-tiny-safetensors",
cache_folder=cache_folder,
)

# Only the safetensors file must be loaded
pytorch_files = list(Path(cache_folder).glob("**/pytorch_model.bin"))
assert 0 == len(pytorch_files), "PyTorch model file must not be downloaded."
safetensors_files = list(Path(cache_folder).glob("**/model.safetensors"))
assert 1 == len(safetensors_files), "Safetensors model file must be downloaded."

with tempfile.TemporaryDirectory() as cache_folder:
transformer = Transformer(
"sentence-transformers-testing/stsb-bert-tiny-safetensors",
cache_dir=cache_folder,
model_args={"use_safetensors": False},
)
pooling = Pooling(transformer.get_word_embedding_dimension())
pytorch_model = SentenceTransformer(modules=[transformer, pooling])

# Only the pytorch file must be loaded
pytorch_files = list(Path(cache_folder).glob("**/pytorch_model.bin"))
assert 1 == len(pytorch_files), "PyTorch model file must be downloaded."
safetensors_files = list(Path(cache_folder).glob("**/model.safetensors"))
assert 0 == len(safetensors_files), "Safetensors model file must not be downloaded."

sentences = ["This is a test sentence", "This is another test sentence"]
assert torch.equal(
safetensors_model.encode(sentences, convert_to_tensor=True),
pytorch_model.encode(sentences, convert_to_tensor=True),
), "Ensure that Safetensors and PyTorch loaded models result in identical embeddings"


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA must be available to test moving devices effectively.")
def test_to() -> None:
model = SentenceTransformer("sentence-transformers-testing/stsb-bert-tiny-safetensors", device="cpu")

test_device = torch.device("cuda")
assert model.device.type == "cpu"
assert test_device.type == "cuda"

model.to(test_device)
assert model.device.type == "cuda", "The model device should have updated"

@unittest.skipUnless(torch.cuda.is_available(), reason="CUDA must be available to test moving devices effectively.")
def test_to(self):
model = SentenceTransformer("sentence-transformers-testing/stsb-bert-tiny-safetensors", device="cpu")
model.encode("Test sentence")
assert model.device.type == "cuda", "Encoding shouldn't change the device"

test_device = torch.device("cuda")
self.assertEqual(model.device.type, "cpu")
self.assertEqual(test_device.type, "cuda")
assert model._target_device == model.device, "Prevent backwards compatibility failure for _target_device"
model._target_device = "cpu"
assert model.device.type == "cpu", "Ensure that setting `_target_device` doesn't crash."

model.to(test_device)
self.assertEqual(model.device.type, "cuda", msg="The model device should have updated")

model.encode("Test sentence")
self.assertEqual(model.device.type, "cuda", msg="Encoding shouldn't change the device")
def test_save_to_hub(monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture) -> None:
def mock_create_repo(self, repo_id, **kwargs):
return RepoUrl(f"https://huggingface.co/{repo_id}")

mock_upload_folder_kwargs = {}

def mock_upload_folder(self, **kwargs):
nonlocal mock_upload_folder_kwargs
mock_upload_folder_kwargs = kwargs

def mock_list_repo_refs(self, repo_id=None, **kwargs):
try:
git_ref_info = GitRefInfo(name="main", ref="refs/heads/main", target_commit="123456")
except TypeError:
git_ref_info = GitRefInfo(dict(name="main", ref="refs/heads/main", targetCommit="123456"))
return GitRefs(branches=[git_ref_info], converts=[], tags=[])

monkeypatch.setattr(HfApi, "create_repo", mock_create_repo)
monkeypatch.setattr(HfApi, "upload_folder", mock_upload_folder)
monkeypatch.setattr(HfApi, "list_repo_refs", mock_list_repo_refs)

model = SentenceTransformer("sentence-transformers-testing/stsb-bert-tiny-safetensors")
url = model.save_to_hub("sentence-transformers-testing/stsb-bert-tiny-safetensors")
assert mock_upload_folder_kwargs["repo_id"] == "sentence-transformers-testing/stsb-bert-tiny-safetensors"
assert url == "https://huggingface.co/sentence-transformers-testing/stsb-bert-tiny-safetensors/commit/123456"
mock_upload_folder_kwargs.clear()

with pytest.raises(
ValueError, match="Providing an `organization` to `save_to_hub` is deprecated, please only use `repo_id`."
):
model.save_to_hub("sentence-transformers-testing/stsb-bert-tiny-safetensors", organization="unrelated")

caplog.clear()
with caplog.at_level(logging.WARNING):
url = model.save_to_hub(
"sentence-transformers-testing/stsb-bert-tiny-safetensors", organization="sentence-transformers-testing"
)
assert mock_upload_folder_kwargs["repo_id"] == "sentence-transformers-testing/stsb-bert-tiny-safetensors"
assert url == "https://huggingface.co/sentence-transformers-testing/stsb-bert-tiny-safetensors/commit/123456"
assert len(caplog.record_tuples) == 1
assert (
caplog.record_tuples[0][2]
== 'Providing an `organization` to `save_to_hub` is deprecated, please only use `repo_id="sentence-transformers-testing/stsb-bert-tiny-safetensors"` instead.'
)
mock_upload_folder_kwargs.clear()

caplog.clear()
with caplog.at_level(logging.WARNING):
url = model.save_to_hub("stsb-bert-tiny-safetensors", organization="sentence-transformers-testing")
assert mock_upload_folder_kwargs["repo_id"] == "sentence-transformers-testing/stsb-bert-tiny-safetensors"
assert url == "https://huggingface.co/sentence-transformers-testing/stsb-bert-tiny-safetensors/commit/123456"
assert len(caplog.record_tuples) == 1
assert (
caplog.record_tuples[0][2]
== 'Providing an `organization` to `save_to_hub` is deprecated, please use `repo_id="sentence-transformers-testing/stsb-bert-tiny-safetensors"` instead.'
)
mock_upload_folder_kwargs.clear()

self.assertEqual(model._target_device, model.device, msg="Prevent backwards compatibility failure for _target_device")
model._target_device = "cpu"
self.assertEqual(model.device.type, "cpu", msg="Ensure that setting `_target_device` doesn't crash.")
url = model.save_to_hub(
"sentence-transformers-testing/stsb-bert-tiny-safetensors", local_model_path="my_fake_local_model_path"
)
assert mock_upload_folder_kwargs["repo_id"] == "sentence-transformers-testing/stsb-bert-tiny-safetensors"
assert mock_upload_folder_kwargs["folder_path"] == "my_fake_local_model_path"
assert url == "https://huggingface.co/sentence-transformers-testing/stsb-bert-tiny-safetensors/commit/123456"