Skip to content

Commit

Permalink
Make Transformers use cache files when hf.co is down (huggingface#16362)
Browse files Browse the repository at this point in the history
* Make Transformers use cache files when hf.co is down

* Fix tests

* Was there a random circleCI failure?

* Isolate patches

* Style

* Comment out the failure since it doesn't fail anymore

* Better comment
  • Loading branch information
sgugger authored Mar 23, 2022
1 parent 8a69e02 commit c595b6e
Show file tree
Hide file tree
Showing 13 changed files with 148 additions and 35 deletions.
14 changes: 9 additions & 5 deletions src/transformers/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,12 +620,16 @@ def _get_config_dict(
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {configuration_file}."
)
except HTTPError:
except HTTPError as err:
raise EnvironmentError(
"We couldn't connect to 'https://huggingface.co/' to load this model and it looks like "
f"{pretrained_model_name_or_path} is not the path to a directory containing a {configuration_file} "
"file.\nCheckout your internet connection or see how to run the library in offline mode at "
"'https://huggingface.co/docs/transformers/installation#offline-mode'."
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
)
except ValueError:
raise EnvironmentError(
"We couldn't connect to 'https://huggingface.co/' to load this model, couldn't find it in the cached "
f"files and it looks like {pretrained_model_name_or_path} is not the path to a directory containing a "
"{configuration_file} file.\nCheckout your internet connection or see how to run the library in "
"offline mode at 'https://huggingface.co/docs/transformers/installation#offline-mode'."
)
except EnvironmentError:
raise EnvironmentError(
Expand Down
10 changes: 7 additions & 3 deletions src/transformers/feature_extraction_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,10 +427,14 @@ def get_feature_extractor_dict(
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {FEATURE_EXTRACTOR_NAME}."
)
except HTTPError:
except HTTPError as err:
raise EnvironmentError(
"We couldn't connect to 'https://huggingface.co/' to load this model and it looks like "
f"{pretrained_model_name_or_path} is not the path to a directory conaining a "
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
)
except ValueError:
raise EnvironmentError(
"We couldn't connect to 'https://huggingface.co/' to load this model, couldn't find it in the cached "
f"files and it looks like {pretrained_model_name_or_path} is not the path to a directory containing a "
f"{FEATURE_EXTRACTOR_NAME} file.\nCheckout your internet connection or see how to run the library in "
"offline mode at 'https://huggingface.co/docs/transformers/installation#offline-mode'."
)
Expand Down
13 changes: 9 additions & 4 deletions src/transformers/modeling_flax_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,11 +523,16 @@ def from_pretrained(
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {filename}."
)
except HTTPError:
except HTTPError as err:
raise EnvironmentError(
"We couldn't connect to 'https://huggingface.co/' to load this model and it looks like "
f"{pretrained_model_name_or_path} is not the path to a directory conaining a a file named "
f"{FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}.\n"
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n"
f"{err}"
)
except ValueError:
raise EnvironmentError(
"We couldn't connect to 'https://huggingface.co/' to load this model, couldn't find it in the cached "
f"files and it looks like {pretrained_model_name_or_path} is not the path to a directory "
f"containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}.\n"
"Checkout your internet connection or see how to run the library in offline mode at "
"'https://huggingface.co/docs/transformers/installation#offline-mode'."
)
Expand Down
13 changes: 9 additions & 4 deletions src/transformers/modeling_tf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1678,11 +1678,16 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {filename}."
)
except HTTPError:
except HTTPError as err:
raise EnvironmentError(
"We couldn't connect to 'https://huggingface.co/' to load this model and it looks like "
f"{pretrained_model_name_or_path} is not the path to a directory conaining a a file named "
f"{TF2_WEIGHTS_NAME} or {WEIGHTS_NAME}.\n"
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n"
f"{err}"
)
except ValueError:
raise EnvironmentError(
"We couldn't connect to 'https://huggingface.co/' to load this model, couldn't find it in the cached "
f"files and it looks like {pretrained_model_name_or_path} is not the path to a directory "
f"containing a file named {TF2_WEIGHTS_NAME} or {WEIGHTS_NAME}.\n"
"Checkout your internet connection or see how to run the library in offline mode at "
"'https://huggingface.co/docs/transformers/installation#offline-mode'."
)
Expand Down
14 changes: 10 additions & 4 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1409,11 +1409,17 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {filename}."
)
except HTTPError:
except HTTPError as err:
raise EnvironmentError(
"We couldn't connect to 'https://huggingface.co/' to load this model and it looks like "
f"{pretrained_model_name_or_path} is not the path to a directory conaining a a file named "
f"{WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or {FLAX_WEIGHTS_NAME}.\n"
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n"
f"{err}"
)
except ValueError:
raise EnvironmentError(
"We couldn't connect to 'https://huggingface.co/' to load this model, couldn't find it in the cached "
f"files and it looks like {pretrained_model_name_or_path} is not the path to a directory "
f"containing a file named {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or "
f"{FLAX_WEIGHTS_NAME}.\n"
"Checkout your internet connection or see how to run the library in offline mode at "
"'https://huggingface.co/docs/transformers/installation#offline-mode'."
)
Expand Down
11 changes: 3 additions & 8 deletions src/transformers/tokenization_utils_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,6 @@
import numpy as np
from packaging import version

from requests import HTTPError

from . import __version__
from .dynamic_module_utils import custom_object_save
from .utils import (
Expand Down Expand Up @@ -1751,12 +1749,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike],
logger.debug(f"{pretrained_model_name_or_path} does not contain a file named {file_path}.")
resolved_vocab_files[file_id] = None

except HTTPError as err:
if "404 Client Error" in str(err):
logger.debug(f"Connection problem to access {file_path}.")
resolved_vocab_files[file_id] = None
else:
raise err
except ValueError:
logger.debug(f"Connection problem to access {file_path} and it wasn't found in the cache.")
resolved_vocab_files[file_id] = None

if len(unresolved_files) > 0:
logger.info(
Expand Down
11 changes: 9 additions & 2 deletions src/transformers/utils/hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,10 +498,17 @@ def get_from_cache(
# between the HEAD and the GET (unlikely, but hey).
if 300 <= r.status_code <= 399:
url_to_download = r.headers["Location"]
except (requests.exceptions.SSLError, requests.exceptions.ProxyError):
except (
requests.exceptions.SSLError,
requests.exceptions.ProxyError,
RepositoryNotFoundError,
EntryNotFoundError,
RevisionNotFoundError,
):
# Actually raise for those subclasses of ConnectionError
# Also raise the custom errors coming from a non existing repo/branch/file as they are caught later on.
raise
except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
except (HTTPError, requests.exceptions.ConnectionError, requests.exceptions.Timeout):
# Otherwise, our Internet connection is down.
# etag is None
pass
Expand Down
18 changes: 17 additions & 1 deletion tests/test_configuration_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import sys
import tempfile
import unittest
import unittest.mock
import unittest.mock as mock
from pathlib import Path

from huggingface_hub import Repository, delete_repo, login
Expand Down Expand Up @@ -304,6 +304,22 @@ def test_config_common_kwargs_is_complete(self):
f"pick another value for them: {', '.join(keys_with_defaults)}."
)

def test_cached_files_are_used_when_internet_is_down(self):
# A mock response for an HTTP head request to emulate server down
response_mock = mock.Mock()
response_mock.status_code = 500
response_mock.headers = []
response_mock.raise_for_status.side_effect = HTTPError

# Download this model to make sure it's in the cache.
_ = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")

# Under the mock environment we get a 500 error when trying to reach the model.
with mock.patch("transformers.utils.hub.requests.head", return_value=response_mock) as mock_head:
_ = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
# This check we did call the fake head request
mock_head.assert_called()


class ConfigurationVersioningTest(unittest.TestCase):
def test_local_versioning(self):
Expand Down
18 changes: 18 additions & 0 deletions tests/test_feature_extraction_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import sys
import tempfile
import unittest
import unittest.mock as mock
from pathlib import Path

from huggingface_hub import Repository, delete_repo, login
Expand Down Expand Up @@ -116,6 +117,23 @@ def test_init_without_params(self):
self.assertIsNotNone(feat_extract)


class FeatureExtractorUtilTester(unittest.TestCase):
def test_cached_files_are_used_when_internet_is_down(self):
# A mock response for an HTTP head request to emulate server down
response_mock = mock.Mock()
response_mock.status_code = 500
response_mock.headers = []
response_mock.raise_for_status.side_effect = HTTPError

# Download this model to make sure it's in the cache.
_ = Wav2Vec2FeatureExtractor.from_pretrained("hf-internal-testing/tiny-random-wav2vec2")
# Under the mock environment we get a 500 error when trying to reach the model.
with mock.patch("transformers.utils.hub.requests.head", return_value=response_mock) as mock_head:
_ = Wav2Vec2FeatureExtractor.from_pretrained("hf-internal-testing/tiny-random-wav2vec2")
# This check we did call the fake head request
mock_head.assert_called()


@is_staging_test
class FeatureExtractorPushToHubTester(unittest.TestCase):
@classmethod
Expand Down
17 changes: 17 additions & 0 deletions tests/test_modeling_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import sys
import tempfile
import unittest
import unittest.mock as mock
import warnings
from pathlib import Path
from typing import Dict, List, Tuple
Expand Down Expand Up @@ -2272,6 +2273,22 @@ def test_no_super_init_config_and_model(self):
for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.equal(p1, p2))

def test_cached_files_are_used_when_internet_is_down(self):
# A mock response for an HTTP head request to emulate server down
response_mock = mock.Mock()
response_mock.status_code = 500
response_mock.headers = []
response_mock.raise_for_status.side_effect = HTTPError

# Download this model to make sure it's in the cache.
_ = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")

# Under the mock environment we get a 500 error when trying to reach the model.
with mock.patch("transformers.utils.hub.requests.head", return_value=response_mock) as mock_head:
_ = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
# This check we did call the fake head request
mock_head.assert_called()


@require_torch
@is_staging_test
Expand Down
17 changes: 17 additions & 0 deletions tests/test_modeling_tf_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import random
import tempfile
import unittest
import unittest.mock as mock
from importlib import import_module
from typing import List, Tuple

Expand Down Expand Up @@ -1555,6 +1556,22 @@ def test_top_k_top_p_filtering(self):
tf.debugging.assert_near(non_inf_output, non_inf_expected_output, rtol=1e-12)
tf.debugging.assert_equal(non_inf_idx, non_inf_expected_idx)

def test_cached_files_are_used_when_internet_is_down(self):
# A mock response for an HTTP head request to emulate server down
response_mock = mock.Mock()
response_mock.status_code = 500
response_mock.headers = []
response_mock.raise_for_status.side_effect = HTTPError

# Download this model to make sure it's in the cache.
_ = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")

# Under the mock environment we get a 500 error when trying to reach the model.
with mock.patch("transformers.utils.hub.requests.head", return_value=response_mock) as mock_head:
_ = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
# This check we did call the fake head request
mock_head.assert_called()

# tests whether the unpack_inputs function behaves as expected
def test_unpack_inputs(self):
class DummyModel:
Expand Down
19 changes: 19 additions & 0 deletions tests/test_tokenization_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import sys
import tempfile
import unittest
import unittest.mock as mock
from collections import OrderedDict
from itertools import takewhile
from pathlib import Path
Expand Down Expand Up @@ -3742,6 +3743,24 @@ def test_save_slow_from_fast_and_reload_fast(self):
self.rust_tokenizer_class.from_pretrained(tmp_dir_2)


class TokenizerUtilTester(unittest.TestCase):
def test_cached_files_are_used_when_internet_is_down(self):
# A mock response for an HTTP head request to emulate server down
response_mock = mock.Mock()
response_mock.status_code = 500
response_mock.headers = []
response_mock.raise_for_status.side_effect = HTTPError

# Download this model to make sure it's in the cache.
_ = BertTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")

# Under the mock environment we get a 500 error when trying to reach the model.
with mock.patch("transformers.utils.hub.requests.head", return_value=response_mock) as mock_head:
_ = BertTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
# This check we did call the fake head request
mock_head.assert_called()


@is_staging_test
class TokenizerPushToHubTester(unittest.TestCase):
vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]", "bla", "blou"]
Expand Down
8 changes: 4 additions & 4 deletions tests/utils/test_offline.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,10 @@ def offline_socket(*args, **kwargs): raise socket.error("Offline mode is enabled
# next emulate no network
cmd = [sys.executable, "-c", "\n".join([load, mock, run])]

# should normally fail as it will fail to lookup the model files w/o the network
env["TRANSFORMERS_OFFLINE"] = "0"
result = subprocess.run(cmd, env=env, check=False, capture_output=True)
self.assertEqual(result.returncode, 1, result.stderr)
# Doesn't fail anymore since the model is in the cache due to other tests, so commenting this.
# env["TRANSFORMERS_OFFLINE"] = "0"
# result = subprocess.run(cmd, env=env, check=False, capture_output=True)
# self.assertEqual(result.returncode, 1, result.stderr)

# should succeed as TRANSFORMERS_OFFLINE=1 tells it to use local files
env["TRANSFORMERS_OFFLINE"] = "1"
Expand Down

0 comments on commit c595b6e

Please sign in to comment.