Skip to content

Commit

Permalink
CLI: use hub's create_commit (huggingface#17755)
Browse files Browse the repository at this point in the history
* use create_commit

* better commit message and description

* touch setup.py to trigger cache update

* add hub version gating
  • Loading branch information
gante authored Jun 22, 2022
1 parent c366ce1 commit 0d0c392
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 32 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/add-model-like.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ jobs:
id: cache
with:
path: ~/venv/
key: v3-tests_model_like-${{ hashFiles('setup.py') }}
key: v4-tests_model_like-${{ hashFiles('setup.py') }}

- name: Create virtual environment on cache miss
if: steps.cache.outputs.cache-hit != 'true'
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/model-templates.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jobs:
id: cache
with:
path: ~/venv/
key: v3-tests_templates-${{ hashFiles('setup.py') }}
key: v4-tests_templates-${{ hashFiles('setup.py') }}

- name: Create virtual environment on cache miss
if: steps.cache.outputs.cache-hit != 'true'
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/update_metdata.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jobs:
id: cache
with:
path: ~/venv/
key: v2-metadata-${{ hashFiles('setup.py') }}
key: v3-metadata-${{ hashFiles('setup.py') }}

- name: Create virtual environment on cache miss
if: steps.cache.outputs.cache-hit != 'true'
Expand Down
81 changes: 52 additions & 29 deletions src/transformers/commands/pt_to_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@

import numpy as np
from datasets import load_dataset
from packaging import version

from huggingface_hub import Repository, upload_file
import huggingface_hub

from .. import AutoConfig, AutoFeatureExtractor, AutoTokenizer, is_tf_available, is_torch_available
from ..utils import logging
Expand All @@ -45,7 +46,9 @@ def convert_command_factory(args: Namespace):
Returns: ServeCommand
"""
return PTtoTFCommand(args.model_name, args.local_dir, args.new_weights, args.no_pr, args.push)
return PTtoTFCommand(
args.model_name, args.local_dir, args.new_weights, args.no_pr, args.push, args.extra_commit_description
)


class PTtoTFCommand(BaseTransformersCLICommand):
Expand Down Expand Up @@ -89,6 +92,12 @@ def register_subcommand(parser: ArgumentParser):
action="store_true",
help="Optional flag to push the weights directly to `main` (requires permissions)",
)
train_parser.add_argument(
"--extra-commit-description",
type=str,
default="",
help="Optional additional commit description to use when opening a PR (e.g. to tag the owner).",
)
train_parser.set_defaults(func=convert_command_factory)

@staticmethod
Expand Down Expand Up @@ -134,13 +143,23 @@ def _find_pt_tf_differences(pt_out, tf_out, differences, attr_name=""):

return _find_pt_tf_differences(pt_outputs, tf_outputs, {})

def __init__(self, model_name: str, local_dir: str, new_weights: bool, no_pr: bool, push: bool, *args):
def __init__(
self,
model_name: str,
local_dir: str,
new_weights: bool,
no_pr: bool,
push: bool,
extra_commit_description: str,
*args
):
self._logger = logging.get_logger("transformers-cli/pt_to_tf")
self._model_name = model_name
self._local_dir = local_dir if local_dir else os.path.join("/tmp", model_name)
self._new_weights = new_weights
self._no_pr = no_pr
self._push = push
self._extra_commit_description = extra_commit_description

def get_text_inputs(self):
tokenizer = AutoTokenizer.from_pretrained(self._local_dir)
Expand Down Expand Up @@ -170,10 +189,17 @@ def get_image_inputs(self):
return pt_input, tf_input

def run(self):
if version.parse(huggingface_hub.__version__) < version.parse("0.8.1"):
raise ImportError(
"The huggingface_hub version must be >= 0.8.1 to use this command. Please update your huggingface_hub"
" installation."
)
else:
from huggingface_hub import Repository, create_commit
from huggingface_hub._commit_api import CommitOperationAdd

# Fetch remote data
# TODO: implement a solution to pull a specific PR/commit, so we can use this CLI to validate pushes.
repo = Repository(local_dir=self._local_dir, clone_from=self._model_name)
repo.git_pull() # in case the repo already exists locally, but with an older commit

# Load config and get the appropriate architecture -- the latter is needed to convert the head's weights
config = AutoConfig.from_pretrained(self._local_dir)
Expand Down Expand Up @@ -240,32 +266,29 @@ def run(self):
)
)

commit_message = "Update TF weights" if self._new_weights else "Add TF weights"
if self._push:
repo.git_add(auto_lfs_track=True)
repo.git_commit("Add TF weights")
repo.git_commit(commit_message)
repo.git_push(blocking=True) # this prints a progress bar with the upload
self._logger.warn(f"TF weights pushed into {self._model_name}")
elif not self._no_pr:
# TODO: remove try/except when the upload to PR feature is released
# (https://github.com/huggingface/huggingface_hub/pull/884)
try:
self._logger.warn("Uploading the weights into a new PR...")
hub_pr_url = upload_file(
path_or_fileobj=tf_weights_path,
path_in_repo=TF_WEIGHTS_NAME,
repo_id=self._model_name,
create_pr=True,
pr_commit_summary="Add TF weights",
pr_commit_description=(
"Model converted by the `transformers`' `pt_to_tf` CLI -- all converted model outputs and"
" hidden layers were validated against its Pytorch counterpart. Maximum crossload output"
f" difference={max_crossload_diff:.3e}; Maximum converted output"
f" difference={max_conversion_diff:.3e}."
),
)
self._logger.warn(f"PR open in {hub_pr_url}")
except TypeError:
self._logger.warn(
f"You can now open a PR in https://huggingface.co/{self._model_name}/discussions, manually"
f" uploading the file in {tf_weights_path}"
)
self._logger.warn("Uploading the weights into a new PR...")
commit_descrition = (
"Model converted by the [`transformers`' `pt_to_tf`"
" CLI](https://github.com/huggingface/transformers/blob/main/src/transformers/commands/pt_to_tf.py)."
"\n\nAll converted model outputs and hidden layers were validated against its Pytorch counterpart."
f" Maximum crossload output difference={max_crossload_diff:.3e}; Maximum converted output"
f" difference={max_conversion_diff:.3e}."
)
if self._extra_commit_description:
commit_descrition += "\n\n" + self._extra_commit_description
hub_pr_url = create_commit(
repo_id=self._model_name,
operations=[CommitOperationAdd(path_in_repo=TF_WEIGHTS_NAME, path_or_fileobj=tf_weights_path)],
commit_message=commit_message,
commit_description=commit_descrition,
repo_type="model",
create_pr=True,
)
self._logger.warn(f"PR open in {hub_pr_url}")

0 comments on commit 0d0c392

Please sign in to comment.