[HfFileSystem] Support quoted revisions in path #1888

Merged
7 commits merged on Dec 6, 2023
Changes from 1 commit
tests
lhoestq committed Dec 5, 2023
commit f9487d4541768a6548176ab749634e3ffac827ec
56 changes: 45 additions & 11 deletions tests/test_hf_file_system.py
@@ -1,3 +1,4 @@
import copy
import datetime
import unittest
from typing import Optional
@@ -93,10 +94,34 @@ def test_glob(self):
            sorted(self.hffs.glob(self.hf_path + "@main" + "/*")),
            sorted([self.hf_path + "@main" + "/.gitattributes", self.hf_path + "@main" + "/data"]),
        )
        self.assertEqual(
            self.hffs.glob(self.hf_path + "@refs%2Fpr%2F1" + "/data/*"),
            [self.hf_path + "@refs%2Fpr%2F1" + "/data/binary_data_for_pr.bin"],
        )
        self.assertEqual(
            self.hffs.glob(self.hf_path + "@refs/pr/1" + "/data/*"),
            [self.hf_path + "@refs/pr/1" + "/data/binary_data_for_pr.bin"],
        )
        self.assertEqual(
            self.hffs.glob(self.hf_path + "/data/*", revision="refs/pr/1"),
            [self.hf_path + "@refs/pr/1" + "/data/binary_data_for_pr.bin"],
        )

        self.assertIsNone(
            self.hffs.dircache[self.hf_path + "@main"][0]["last_commit"]
        )  # no detail -> no last_commit in cache

        files = self.hffs.glob(self.hf_path + "@main" + "/*", detail=True, expand_info=False)
        self.assertIsInstance(files, dict)
        self.assertEqual(len(files), 2)
        keys = sorted(files)
        self.assertTrue(
            files[keys[0]]["name"].endswith("/.gitattributes") and files[keys[1]]["name"].endswith("/data")
        )
        self.assertIsNone(
            self.hffs.dircache[self.hf_path + "@main"][0]["last_commit"]
        )  # detail but no expand info -> no last_commit in cache

        files = self.hffs.glob(self.hf_path + "@main" + "/*", detail=True)
        self.assertIsInstance(files, dict)
        self.assertEqual(len(files), 2)
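The three new assertions above exercise three equivalent ways of pointing HfFileSystem.glob at a PR revision. A minimal usage sketch along the same lines (the dataset repo id is hypothetical; the glob calls mirror the test, and the quoted form refs%2Fpr%2F1 is what this PR adds support for):

```python
from urllib.parse import quote

from huggingface_hub import HfFileSystem

fs = HfFileSystem()
repo = "datasets/username/my_dataset"  # hypothetical dataset repo with an open PR (refs/pr/1)

# quote("refs/pr/1", safe="") == "refs%2Fpr%2F1"
quoted_rev = quote("refs/pr/1", safe="")

# 1. Quoted revision in the path (newly supported by this PR)
fs.glob(repo + "@" + quoted_rev + "/data/*")

# 2. Unquoted revision in the path
fs.glob(repo + "@refs/pr/1" + "/data/*")

# 3. Revision passed as a keyword argument
fs.glob(repo + "/data/*", revision="refs/pr/1")
```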
@@ -247,27 +272,32 @@ def test_list_data_directory_with_revision(self):
        files = self.hffs.ls(self.hf_path + "@refs%2Fpr%2F1" + "/data")

        for test_name, files in {
            "rev_in_path": self.hffs.ls(self.hf_path + "@refs%2Fpr%2F1" + "/data"),
            "quoted_rev_in_path": self.hffs.ls(self.hf_path + "@refs%2Fpr%2F1" + "/data"),
            "rev_in_path": self.hffs.ls(self.hf_path + "@refs/pr/1" + "/data"),
            "rev_as_arg": self.hffs.ls(self.hf_path + "/data", revision="refs/pr/1"),
            "rev_in_path_and_as_arg": self.hffs.ls(self.hf_path + "@refs%2Fpr%2F1" + "/data", revision="refs/pr/1"),
            "quoted_rev_in_path_and_rev_as_arg": self.hffs.ls(
                self.hf_path + "@refs%2Fpr%2F1" + "/data", revision="refs/pr/1"
            ),
        }.items():
            with self.subTest(test_name):
                self.assertEqual(len(files), 1)  # only one file in PR
                self.assertEqual(files[0]["type"], "file")
                self.assertTrue(files[0]["name"].endswith("/data/binary_data_for_pr.bin"))  # PR file
                if "quoted_rev_in_path" in test_name:
                    self.assertIn("@refs%2Fpr%2F1", files[0]["name"])
                elif "rev_in_path" in test_name:
                    self.assertIn("@refs/pr/1", files[0]["name"])

    def test_list_root_directory_no_revision_no_detail_then_with_detail(self):
        files = self.hffs.ls(self.hf_path, detail=False)
        self.assertEqual(len(files), 2)
        self.assertTrue(files[0].endswith("/data") and files[1].endswith("/.gitattributes"))
        self.assertIsNone(
            self.hffs.dircache[self.hf_path + "@main"][0]["last_commit"]
        )  # no detail -> no last_commit in cache
        self.assertIsNone(self.hffs.dircache[self.hf_path][0]["last_commit"])  # no detail -> no last_commit in cache

        files = self.hffs.ls(self.hf_path, detail=True)
        self.assertEqual(len(files), 2)
        self.assertTrue(files[0]["name"].endswith("/data") and files[1]["name"].endswith("/.gitattributes"))
        self.assertIsNotNone(self.hffs.dircache[self.hf_path + "@main"][0]["last_commit"])
        self.assertIsNotNone(self.hffs.dircache[self.hf_path][0]["last_commit"])

    def test_find_root_directory_no_revision(self):
        files = self.hffs.find(self.hf_path, detail=False)
@@ -310,14 +340,14 @@ def test_find_root_directory_no_revision_with_incomplete_cache(self):
            repo_type="dataset",
        )

        files = self.hffs.find(self.hf_path, detail=True)
        files = copy.deepcopy(self.hffs.find(self.hf_path, detail=True))
Contributor
Why do we need copy.deepcopy here?

Member Author
@lhoestq Dec 5, 2023
In the next lines we .pop() the cache, which also modifies the files dictionary and makes the test fail. Maybe I can fix find so we don't have to do that.

Contributor
Why do we need a deep copy? Is it because files is a reference to an internal object of self.hffs.find? If that's the case it would be a bit worrying 😕

Member Author
Maybe I can fix find to return a copy of the cache instead of the cache content itself (that users could modify)
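To make the concern concrete, here is a small self-contained sketch (not the real HfFileSystem code, just a hypothetical cache-backed find) showing why mutating the cache also mutates the result the caller is holding, and why the test deep-copies it first:

```python
import copy


class TinyFS:
    """Hypothetical filesystem whose find() returns the cached dicts themselves."""

    def __init__(self):
        # dircache maps a directory path to a list of entry dicts
        self.dircache = {"repo": [{"name": "repo/data.bin", "last_commit": "abc123"}]}

    def find(self, path):
        # Returns references to the cached entries, not copies
        return {entry["name"]: entry for entry in self.dircache[path]}


fs = TinyFS()
files = fs.find("repo")                    # aliases the cache entries
snapshot = copy.deepcopy(fs.find("repo"))  # independent copy, as in the test above

fs.dircache["repo"][0]["last_commit"] = None  # the test later mutates the cache like this

print(files["repo/data.bin"]["last_commit"])     # None -> the aliased result changed too
print(snapshot["repo/data.bin"]["last_commit"])  # "abc123" -> the deep copy is unaffected
```

Having find return copies, as suggested above, would make the explicit deepcopy in the test unnecessary.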


        # some directories not in cache
        self.hffs.dircache.pop(self.hf_path + "@main/data/sub_data")
        self.hffs.dircache.pop(self.hf_path + "/data/sub_data")
        # some files not expanded
        self.hffs.dircache[self.hf_path + "@main/data"][1]["last_commit"] = None

        self.assertEqual(self.hffs.find(self.hf_path, detail=True), files)
        self.hffs.dircache[self.hf_path + "/data"][1]["last_commit"] = None
        out = self.hffs.find(self.hf_path, detail=True)
        self.assertEqual(out, files)

    def test_find_data_file_no_revision(self):
        files = self.hffs.find(self.hf_path + "/data/text_data.txt", detail=False)
@@ -360,8 +390,10 @@ def test_find_data_file_no_revision(self):
            "refs/convert/parquet",
        ),
        ("gpt2@refs/pr/2", None, "model", "gpt2", "refs/pr/2"),
        ("gpt2@refs%2Fpr%2F2", None, "model", "gpt2", "refs/pr/2"),
        ("hf://username/my_model@refs/pr/10", None, "model", "username/my_model", "refs/pr/10"),
        ("hf://username/my_model@refs/pr/10", "refs/pr/10", "model", "username/my_model", "refs/pr/10"),
        ("hf://username/my_model@refs%2Fpr%2F10", "refs/pr/10", "model", "username/my_model", "refs/pr/10"),
    ],
)
def test_resolve_path(
@@ -383,6 +415,8 @@ def test_resolve_path(
        resolved_path.revision,
        resolved_path.path_in_repo,
    ) == (repo_type, repo_id, resolved_revision, path_in_repo)
    if "@" in path:
        assert resolved_path._revision_in_path in path


@pytest.mark.parametrize("path_in_repo", ["", "file.txt", "path/to/file"])
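For completeness, a sketch of how the new quoted-revision cases in the parametrization above play out with HfFileSystem.resolve_path (the method exercised by test_resolve_path); the expected values follow the parametrized tuples, and the exact output format of the prints is just illustrative:

```python
from urllib.parse import quote

from huggingface_hub import HfFileSystem

fs = HfFileSystem()

# quote("refs/pr/2", safe="") == "refs%2Fpr%2F2"
resolved = fs.resolve_path("gpt2@" + quote("refs/pr/2", safe=""))

# Per the parametrized cases above, the quoted revision resolves to the unquoted one:
print(resolved.repo_type)  # "model"
print(resolved.repo_id)    # "gpt2"
print(resolved.revision)   # "refs/pr/2"

# The new assertion above also checks that the revision as written in the path
# (here the quoted form) is remembered on the resolved path as _revision_in_path.
```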