Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 24 additions & 4 deletions src/huggingface_hub/utils/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""Handle typing imports based on system compatibility."""

import sys
from typing import Any, Callable, List, Literal, Type, TypeVar, Union, get_args, get_origin
from typing import Any, Callable, List, Literal, Optional, Set, Type, TypeVar, Union, get_args, get_origin


UNION_TYPES: List[Any] = [Union]
Expand All @@ -33,7 +33,7 @@
_JSON_SERIALIZABLE_TYPES = (int, float, str, bool, type(None))


def is_jsonable(obj: Any, _visited: Optional[Set[int]] = None) -> bool:
    """Check if an object is JSON serializable.

    This is a weak check, as it does not check for the actual JSON serialization, but only for the types of the object.

    An object is considered json serializable if:
    - it is an instance of int, float, str, bool, or NoneType
    - it is a list or tuple and all its items are json serializable
    - it is a dict, all its keys are instances of int, float, str, bool, or NoneType, and all its values are json
      serializable
    - it exposes a `__json__` attribute

    Circular references are detected by tracking the ids of the containers that are currently being traversed (i.e.
    the ancestors of `obj` on the current path). A container that contains itself, directly or indirectly, is
    considered not json serializable. The same object referenced several times without forming a cycle is accepted.

    Args:
        obj: The object to check.
        _visited: Internal use only. Ids of the containers on the current traversal path. Callers should not set it.

    Returns:
        `True` if `obj` passes the type-based check, `False` otherwise (including when a circular reference is found
        or when the structure is nested too deeply to be traversed).
    """
    # Primitives cannot contain references, so they need no cycle bookkeeping.
    if isinstance(obj, _JSON_SERIALIZABLE_TYPES):
        return True

    # The set is created once by the outermost call and shared with every recursive call.
    if _visited is None:
        _visited = set()

    # Meeting an object that is still being traversed means it (indirectly) contains itself.
    obj_id = id(obj)
    if obj_id in _visited:
        return False

    # Register the object before descending into its children.
    _visited.add(obj_id)
    try:
        if isinstance(obj, (list, tuple)):
            return all(is_jsonable(item, _visited) for item in obj)
        if isinstance(obj, dict):
            return all(
                isinstance(key, _JSON_SERIALIZABLE_TYPES) and is_jsonable(value, _visited)
                for key, value in obj.items()
            )
        return hasattr(obj, "__json__")
    except RecursionError:
        # Structure is too deeply nested to be inspected: treat it as not serializable.
        return False
    finally:
        # Only ancestors must stay registered: drop the id on the way out so that the same object can legitimately
        # appear again in a sibling branch.
        _visited.discard(obj_id)


def is_simple_optional_type(type_: Type) -> bool:
Expand Down
4 changes: 4 additions & 0 deletions tests/test_utils_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ class CustomType:
OBJ_WITH_CIRCULAR_REF = {"hello": "world"}
OBJ_WITH_CIRCULAR_REF["recursive"] = OBJ_WITH_CIRCULAR_REF

_nested = {"hello": "world"}
OBJ_WITHOUT_CIRCULAR_REF = {"hello": _nested, "world": [_nested]}


@pytest.mark.parametrize(
"data",
Expand All @@ -33,6 +36,7 @@ class CustomType:
{},
{"name": "Alice", "age": 30},
{0: "LABEL_0", 1.0: "LABEL_1"},
OBJ_WITHOUT_CIRCULAR_REF,
],
)
def test_is_jsonable_success(data):
Expand Down
Loading