Skip to content

Commit

Permalink
Fix sorting properties (#2655)
Browse files Browse the repository at this point in the history
  • Loading branch information
hanouticelina authored Nov 5, 2024
1 parent 4011b5a commit 9fb6a8c
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 9 deletions.
42 changes: 33 additions & 9 deletions src/huggingface_hub/hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1798,8 +1798,8 @@ def list_models(
A tuple of two ints or floats representing a minimum and maximum
carbon footprint to filter the resulting models with in grams.
sort (`Literal["last_modified"]` or `str`, *optional*):
The key with which to sort the resulting models. Possible values
are the properties of the [`huggingface_hub.hf_api.ModelInfo`] class.
The key with which to sort the resulting models. Possible values are "last_modified", "trending_score",
"created_at", "downloads" and "likes".
direction (`Literal[-1]` or `int`, *optional*):
Direction in which to sort. The value `-1` sorts by descending
order while all other values sort by ascending order.
Expand Down Expand Up @@ -1911,7 +1911,15 @@ def list_models(
if len(search_list) > 0:
params["search"] = search_list
if sort is not None:
params["sort"] = "lastModified" if sort == "last_modified" else sort
params["sort"] = (
"lastModified"
if sort == "last_modified"
else "trendingScore"
if sort == "trending_score"
else "createdAt"
if sort == "created_at"
else sort
)
if direction is not None:
params["direction"] = direction
if limit is not None:
Expand Down Expand Up @@ -2010,8 +2018,8 @@ def list_datasets(
search (`str`, *optional*):
A string that will be contained in the returned datasets.
sort (`Literal["last_modified"]` or `str`, *optional*):
The key with which to sort the resulting datasets. Possible
values are the properties of the [`huggingface_hub.hf_api.DatasetInfo`] class.
The key with which to sort the resulting models. Possible values are "last_modified", "trending_score",
"created_at", "downloads" and "likes".
direction (`Literal[-1]` or `int`, *optional*):
Direction in which to sort. The value `-1` sorts by descending
order while all other values sort by ascending order.
Expand Down Expand Up @@ -2121,7 +2129,15 @@ def list_datasets(
if len(search_list) > 0:
params["search"] = search_list
if sort is not None:
params["sort"] = "lastModified" if sort == "last_modified" else sort
params["sort"] = (
"lastModified"
if sort == "last_modified"
else "trendingScore"
if sort == "trending_score"
else "createdAt"
if sort == "created_at"
else sort
)
if direction is not None:
params["direction"] = direction
if limit is not None:
Expand Down Expand Up @@ -2193,8 +2209,8 @@ def list_spaces(
linked (`bool`, *optional*):
Whether to return Spaces that make use of either a model or a dataset.
sort (`Literal["last_modified"]` or `str`, *optional*):
The key with which to sort the resulting Spaces. Possible
values are the properties of the [`huggingface_hub.hf_api.SpaceInfo`]` class.
The key with which to sort the resulting models. Possible values are "last_modified", "trending_score",
"created_at" and "likes".
direction (`Literal[-1]` or `int`, *optional*):
Direction in which to sort. The value `-1` sorts by descending
order while all other values sort by ascending order.
Expand Down Expand Up @@ -2230,7 +2246,15 @@ def list_spaces(
if search is not None:
params["search"] = search
if sort is not None:
params["sort"] = "lastModified" if sort == "last_modified" else sort
params["sort"] = (
"lastModified"
if sort == "last_modified"
else "trendingScore"
if sort == "trending_score"
else "createdAt"
if sort == "created_at"
else sort
)
if direction is not None:
params["direction"] = direction
if limit is not None:
Expand Down
24 changes: 24 additions & 0 deletions tests/test_hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1764,6 +1764,30 @@ def test_list_models_complex_query(self):
assert isinstance(model, ModelInfo)
assert all(tag in model.tags for tag in ["bert", "jax"])

def test_list_models_sort_trending_score(self):
models = list(self._api.list_models(sort="trending_score", limit=10))
assert len(models) == 10
assert isinstance(models[0], ModelInfo)
assert all(model.trending_score is not None for model in models)

def test_list_models_sort_created_at(self):
models = list(self._api.list_models(sort="created_at", limit=10))
assert len(models) == 10
assert isinstance(models[0], ModelInfo)
assert all(model.created_at is not None for model in models)

def test_list_models_sort_downloads(self):
models = list(self._api.list_models(sort="downloads", limit=10))
assert len(models) == 10
assert isinstance(models[0], ModelInfo)
assert all(model.downloads is not None for model in models)

def test_list_models_sort_likes(self):
models = list(self._api.list_models(sort="likes", limit=10))
assert len(models) == 10
assert isinstance(models[0], ModelInfo)
assert all(model.likes is not None for model in models)

def test_list_models_with_config(self):
for model in self._api.list_models(filter=("adapter-transformers", "bert"), fetch_config=True, limit=20):
self.assertIsNotNone(model.config)
Expand Down

0 comments on commit 9fb6a8c

Please sign in to comment.