Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ dependencies = [
"tensorflow_datasets",
"tqdm",
"transformers<=4.57.1", # Note: transformers==4.57.2 has a bug, see more information at https://github.com/google/tunix/issues/794
"tenacity",
]

[project.optional-dependencies]
Expand Down
29 changes: 25 additions & 4 deletions tunix/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import jax.numpy as jnp
import numpy as np
import qwix
import tenacity
from tunix.rl import reshard
from tunix.utils import env_utils

Expand Down Expand Up @@ -268,16 +269,36 @@ def __call__(self, *args, **kwargs):
return score


@tenacity.retry(
stop=tenacity.stop_after_attempt(3),
wait=tenacity.wait_exponential(multiplier=1, min=4, max=10),
reraise=True,
)
def safe_list_files(repo_id):
return huggingface_hub.list_repo_files(repo_id)


@tenacity.retry(
stop=tenacity.stop_after_attempt(3),
wait=tenacity.wait_exponential(multiplier=1, min=4, max=10),
reraise=True,
)
def safe_download(repo_id, filename, local_dir):
return huggingface_hub.hf_hub_download(
repo_id=repo_id, filename=filename, local_dir=local_dir
)


def download_from_huggingface(repo_id: str, model_path: str):
"""Download checkpoint files from huggingface."""
print('Make sure you logged in to the huggingface cli.')
all_files = huggingface_hub.list_repo_files(repo_id)

all_files = safe_list_files(repo_id)
filtered_files = [f for f in all_files if not f.startswith('original/')]

for filename in filtered_files:
huggingface_hub.hf_hub_download(
repo_id=repo_id, filename=filename, local_dir=model_path
)

safe_download(repo_id=repo_id, filename=filename, local_dir=model_path)
print(f'Downloaded {filtered_files} to: {model_path}')


Expand Down