Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Model versioning #8324

Merged
merged 18 commits into from
Nov 10, 2020
Merged
Prev Previous commit
Next Next commit
Add doc + pass kwarg everywhere
  • Loading branch information
julien-c committed Nov 6, 2020
commit fe854127821e2edb80328ccedc1efca82d1e5d1f
4 changes: 4 additions & 0 deletions src/transformers/configuration_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,10 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
proxies (:obj:`Dict[str, str]`, `optional`):
A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
identifier allowed by git.
return_unused_kwargs (:obj:`bool`, `optional`, defaults to :obj:`False`):
If :obj:`False`, then this function returns just the final configuration object.

Expand Down
4 changes: 4 additions & 0 deletions src/transformers/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs) -> "Pretr
proxies (:obj:`Dict[str, str]`, `optional`):
A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
identifier allowed by git.
return_unused_kwargs (:obj:`bool`, `optional`, defaults to :obj:`False`):
If :obj:`False`, then this function returns just the final configuration object.

Expand Down
21 changes: 10 additions & 11 deletions src/transformers/modeling_flax_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
proxies = kwargs.pop("proxies", None)
# output_loading_info = kwargs.pop("output_loading_info", False)
local_files_only = kwargs.pop("local_files_only", False)
revision = kwargs.pop("revision", None)

# Load config if we don't provide a configuration
if not isinstance(config, PretrainedConfig):
Expand All @@ -120,6 +121,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
revision=revision,
**kwargs,
)
else:
Expand All @@ -130,7 +132,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
archive_file = pretrained_model_name_or_path
else:
archive_file = hf_bucket_url(pretrained_model_name_or_path, filename=WEIGHTS_NAME)
archive_file = hf_bucket_url(pretrained_model_name_or_path, filename=WEIGHTS_NAME, revision=revision)

# redirect to the cache, if necessary
try:
Expand All @@ -142,16 +144,13 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
resume_download=resume_download,
local_files_only=local_files_only,
)
except EnvironmentError:
if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
msg = f"Couldn't reach server at '{archive_file}' to download pretrained weights."
else:
msg = (
f"Model name '{pretrained_model_name_or_path}' "
f"was not found in model name list ({', '.join(cls.pretrained_model_archive_map.keys())}). "
f"We assumed '{archive_file}' was a path or url to model weight files but "
"couldn't find any such file at this path or url."
)
except EnvironmentError as err:
logger.error(err)
msg = (
f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named one of {TF2_WEIGHTS_NAME}, {WEIGHTS_NAME}.\n\n"
)
raise EnvironmentError(msg)

if resolved_archive_file == archive_file:
Expand Down
4 changes: 4 additions & 0 deletions src/transformers/modeling_tf_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,10 @@
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to only look at local files (e.g., not try downloading the model).
revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
identifier allowed by git.
kwargs (additional keyword arguments, `optional`):
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
:obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or
Expand Down
7 changes: 7 additions & 0 deletions src/transformers/modeling_tf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,6 +572,10 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to only look at local files (e.g., not try downloading the model).
revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
identifier allowed by git.
mirror(:obj:`str`, `optional`, defaults to :obj:`None`):
Mirror source to accelerate downloads in China. If you are from China and have an accessibility
problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
Expand Down Expand Up @@ -613,6 +617,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
proxies = kwargs.pop("proxies", None)
output_loading_info = kwargs.pop("output_loading_info", False)
local_files_only = kwargs.pop("local_files_only", False)
revision = kwargs.pop("revision", None)
mirror = kwargs.pop("mirror", None)

# Load config if we don't provide a configuration
Expand All @@ -627,6 +632,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
revision=revision,
**kwargs,
)
else:
Expand Down Expand Up @@ -655,6 +661,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
archive_file = hf_bucket_url(
pretrained_model_name_or_path,
filename=(WEIGHTS_NAME if from_pt else TF2_WEIGHTS_NAME),
revision=revision,
mirror=mirror,
)

Expand Down
5 changes: 5 additions & 0 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -813,6 +813,10 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to only look at local files (e.g., not try downloading the model).
revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
identifier allowed by git.
mirror(:obj:`str`, `optional`, defaults to :obj:`None`):
Mirror source to accelerate downloads in China. If you are from China and have an accessibility
problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
Expand Down Expand Up @@ -869,6 +873,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
revision=revision,
**kwargs,
)
else:
Expand Down
4 changes: 4 additions & 0 deletions src/transformers/tokenization_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,10 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
proxies (:obj:`Dict[str, str]`, `optional`):
A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
identifier allowed by git.
use_fast (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to try to load the fast version of the tokenizer.
kwargs (additional keyword arguments, `optional`):
Expand Down
9 changes: 8 additions & 1 deletion src/transformers/tokenization_utils_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1517,6 +1517,10 @@ def from_pretrained(cls, pretrained_model_name_or_path, *init_inputs, **kwargs):
proxies (:obj:`Dict[str, str]`, `optional`):
A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
identifier allowed by git.
inputs (additional positional arguments, `optional`):
Will be passed along to the Tokenizer ``__init__`` method.
kwargs (additional keyword arguments, `optional`):
Expand Down Expand Up @@ -1551,6 +1555,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *init_inputs, **kwargs):
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", False)
revision = kwargs.pop("revision", None)

s3_models = list(cls.max_model_input_sizes.keys())
vocab_files = {}
Expand Down Expand Up @@ -1602,7 +1607,9 @@ def from_pretrained(cls, pretrained_model_name_or_path, *init_inputs, **kwargs):
logger.info("Didn't find file {}. We won't load it.".format(full_file_name))
full_file_name = None
else:
full_file_name = hf_bucket_url(pretrained_model_name_or_path, filename=file_name, mirror=None)
full_file_name = hf_bucket_url(
pretrained_model_name_or_path, filename=file_name, revision=revision, mirror=None
)

vocab_files[file_id] = full_file_name

Expand Down