diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 5aa9d0a770cfa1..0e322e0557fdb4 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -29,7 +29,6 @@ from contextlib import contextmanager from dataclasses import dataclass from functools import partial, wraps -from threading import Thread from typing import Any, Callable, Dict, List, Optional, Tuple, Union from zipfile import is_zipfile @@ -3208,39 +3207,9 @@ def from_pretrained( ) if resolved_archive_file is not None: is_sharded = True - - if resolved_archive_file is not None: - if filename in [WEIGHTS_NAME, WEIGHTS_INDEX_NAME]: - # If the PyTorch file was found, check if there is a safetensors file on the repository - # If there is no safetensors file on the repositories, start an auto conversion - safe_weights_name = SAFE_WEIGHTS_INDEX_NAME if is_sharded else SAFE_WEIGHTS_NAME - has_file_kwargs = { - "revision": revision, - "proxies": proxies, - "token": token, - } - cached_file_kwargs = { - "cache_dir": cache_dir, - "force_download": force_download, - "resume_download": resume_download, - "local_files_only": local_files_only, - "user_agent": user_agent, - "subfolder": subfolder, - "_raise_exceptions_for_gated_repo": False, - "_raise_exceptions_for_missing_entries": False, - "_commit_hash": commit_hash, - **has_file_kwargs, - } - if not has_file(pretrained_model_name_or_path, safe_weights_name, **has_file_kwargs): - Thread( - target=auto_conversion, - args=(pretrained_model_name_or_path,), - kwargs=cached_file_kwargs, - name="Thread-autoconversion", - ).start() - else: - # Otherwise, no PyTorch file was found, maybe there is a TF or Flax model file. - # We try those to give a helpful error message. + if resolved_archive_file is None: + # Otherwise, maybe there is a TF or Flax model file. We try those to give a helpful error + # message. has_file_kwargs = { "revision": revision, "proxies": proxies, diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 57f0f11dbb8a06..1f277c7504561f 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -20,7 +20,6 @@ import os.path import sys import tempfile -import threading import unittest import unittest.mock as mock import uuid @@ -1429,7 +1428,7 @@ def test_safetensors_on_the_fly_wrong_user_opened_pr(self): bot_opened_pr_title = None for discussion in discussions: - if discussion.author == "SFconvertbot": + if discussion.author == "SFconvertBot": bot_opened_pr = True bot_opened_pr_title = discussion.title @@ -1452,51 +1451,6 @@ def test_safetensors_on_the_fly_specific_revision(self): with self.assertRaises(EnvironmentError): BertModel.from_pretrained(self.repo_name, use_safetensors=True, token=self.token, revision="new-branch") - def test_absence_of_safetensors_triggers_conversion(self): - config = BertConfig( - vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37 - ) - initial_model = BertModel(config) - - # Push a model on `main` - initial_model.push_to_hub(self.repo_name, token=self.token, safe_serialization=False) - - # Download the model that doesn't have safetensors - BertModel.from_pretrained(self.repo_name, token=self.token) - - for thread in threading.enumerate(): - if thread.name == "Thread-autoconversion": - thread.join(timeout=10) - - with self.subTest("PR was open with the safetensors account"): - discussions = self.api.get_repo_discussions(self.repo_name) - - bot_opened_pr = None - bot_opened_pr_title = None - - for discussion in discussions: - if discussion.author == "SFconvertbot": - bot_opened_pr = True - bot_opened_pr_title = discussion.title - - self.assertTrue(bot_opened_pr) - self.assertEqual(bot_opened_pr_title, "Adding `safetensors` variant of this model") - - @mock.patch("transformers.safetensors_conversion.spawn_conversion") - def test_absence_of_safetensors_triggers_conversion_failed(self, spawn_conversion_mock): - spawn_conversion_mock.side_effect = HTTPError() - - config = BertConfig( - vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37 - ) - initial_model = BertModel(config) - - # Push a model on `main` - initial_model.push_to_hub(self.repo_name, token=self.token, safe_serialization=False) - - # The auto conversion is mocked to always raise; ensure that it doesn't raise in the main thread - BertModel.from_pretrained(self.repo_name, token=self.token) - @require_torch @is_staging_test