Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions docs/source/en/package_reference/dataclasses.md
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,14 @@ The `@strict` decorator enhances a dataclass with strict validation.

[[autodoc]] dataclasses.strict

### `validate_typed_dict`

Method to validate that a dictionary conforms to the types defined in a `TypedDict` class.

This is the equivalent of dataclass validation, but for `TypedDict`s. Since typed dicts are never instantiated (they are only used by static type checkers), the validation step must be called manually.

[[autodoc]] dataclasses.validate_typed_dict

### `as_validated_field`

Decorator to create a [`validated_field`]. Recommended for fields with a single validator to avoid boilerplate code.
Expand Down
135 changes: 132 additions & 3 deletions src/huggingface_hub/dataclasses.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,33 @@
import inspect
from dataclasses import _MISSING_TYPE, MISSING, Field, field, fields
from functools import wraps
from typing import Any, Callable, ForwardRef, Literal, Optional, Type, TypeVar, Union, get_args, get_origin, overload
from dataclasses import _MISSING_TYPE, MISSING, Field, field, fields, make_dataclass
from functools import lru_cache, wraps
from typing import (
Annotated,
Any,
Callable,
ForwardRef,
Literal,
Optional,
Type,
TypeVar,
Union,
get_args,
get_origin,
overload,
)


try:
# Python 3.11+
from typing import NotRequired, Required # type: ignore
except ImportError:
try:
# In case typing_extensions is installed
from typing_extensions import NotRequired, Required # type: ignore
except ImportError:
# Fallback: create dummy types that will never match
Required = type("Required", (), {}) # type: ignore
NotRequired = type("NotRequired", (), {}) # type: ignore

from .errors import (
StrictDataclassClassValidationError,
Expand All @@ -12,6 +38,9 @@

# Signature of a validator: a callable that receives the value to check and returns nothing.
Validator_T = Callable[[Any], None]
T = TypeVar("T")
# Type variable for TypedDict schemas; bounded by `dict[str, Any]`.
TypedDictType = TypeVar("TypedDictType", bound=dict[str, Any])

# Unique sentinel used as the default value of every field generated from a TypedDict.
# A field still holding this object was not provided by the caller; a dedicated object is
# used (rather than None) so that an explicit None can be told apart from a missing key.
_TYPED_DICT_DEFAULT_VALUE = object()


# The overload decorator helps type checkers understand the different return types
Expand Down Expand Up @@ -223,6 +252,92 @@ def init_with_validate(self, *args, **kwargs) -> None:
return wrap(cls) if cls is not None else wrap


def validate_typed_dict(schema: type[TypedDictType], data: dict) -> None:
    """
    Check that `data` matches the fields and types declared by a TypedDict class.

    Typed dicts are plain dictionaries at runtime, so nothing validates them automatically. This function
    builds a strict dataclass equivalent to `schema` and instantiates it with `data`: the instantiation
    triggers the same validation as the `@strict` decorator.

    Args:
        schema (`type[TypedDictType]`):
            The TypedDict class describing the expected keys and their types.
        data (`dict`):
            The dictionary to check against `schema`.

    Raises:
        `StrictDataclassFieldValidationError`:
            If a value does not have the expected type or is rejected by an `Annotated` validator.
        `TypeError`:
            If `data` contains a key that is not declared in `schema`.

    Example:
    ```py
    >>> from typing import Annotated, TypedDict
    >>> from huggingface_hub.dataclasses import validate_typed_dict

    >>> def positive_int(value: int):
    ...     if not value >= 0:
    ...         raise ValueError(f"Value must be positive, got {value}")

    >>> class User(TypedDict):
    ...     name: str
    ...     age: Annotated[int, positive_int]

    >>> # Valid data
    >>> validate_typed_dict(User, {"name": "John", "age": 30})

    >>> # Invalid type for age
    >>> validate_typed_dict(User, {"name": "John", "age": "30"})
    huggingface_hub.errors.StrictDataclassFieldValidationError: Validation error for field 'age':
        TypeError: Field 'age' expected int, got str (value: '30')

    >>> # Invalid value for age
    >>> validate_typed_dict(User, {"name": "John", "age": -1})
    huggingface_hub.errors.StrictDataclassFieldValidationError: Validation error for field 'age':
        ValueError: Value must be positive, got -1
    ```
    """
    # Instantiating the generated strict dataclass is the validation itself: it raises on the
    # first invalid field. The instance is only a vehicle for the checks and is discarded.
    _build_strict_cls_from_typed_dict(schema)(**data)


@lru_cache
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice!

def _build_strict_cls_from_typed_dict(schema: type[TypedDictType]) -> Type:
    """Build a strict dataclass whose fields mirror the annotations of a TypedDict class.

    Every generated field defaults to `_TYPED_DICT_DEFAULT_VALUE`, so a missing key never fails at
    instantiation time and is instead seen by the field validation as that sentinel.

    Args:
        schema (`type[TypedDictType]`):
            The TypedDict class to convert.

    Returns:
        `Type`: a dataclass named after `schema` and wrapped with `strict`.
    """
    # Read the raw annotations rather than calling `get_type_hints`: resolving ForwardRefs might
    # fail, and ForwardRefs are not validated by @strict anyway.
    raw_annotations = schema.__dict__.get("__annotations__", {})
    type_hints = {}
    for name, hint in raw_annotations.items():
        # A bare `None` annotation stands for NoneType.
        type_hints[name] = type(None) if hint is None else hint

    # In a non-total TypedDict every key is optional unless it is explicitly marked, so wrap the
    # unmarked hints in NotRequired. For Annotated hints the marker goes on the inner type.
    is_total = getattr(schema, "__total__", True)
    if not is_total:
        for name, hint in list(type_hints.items()):
            if get_origin(hint) is Annotated:
                base, *meta = get_args(hint)
                if not _is_required_or_notrequired(base):
                    base = NotRequired[base]
                type_hints[name] = Annotated[(base, *meta)]
            elif not _is_required_or_notrequired(hint):
                type_hints[name] = NotRequired[hint]

    # Turn each hint into a `(name, type, field)` triple as expected by `make_dataclass`.
    dataclass_fields = []
    for name, hint in type_hints.items():
        if get_origin(hint) is Annotated:
            # NOTE: only the first Annotated metadata entry is registered as the field validator.
            field_type, first_meta, *_ = get_args(hint)
            spec = field(default=_TYPED_DICT_DEFAULT_VALUE, metadata={"validator": first_meta})
        else:
            field_type = hint
            spec = field(default=_TYPED_DICT_DEFAULT_VALUE)
        dataclass_fields.append((name, field_type, spec))

    # Create a strict dataclass from the collected fields.
    return strict(make_dataclass(schema.__name__, dataclass_fields))


def validated_field(
validator: Union[list[Validator_T], Validator_T],
default: Union[Any, _MISSING_TYPE] = MISSING,
Expand Down Expand Up @@ -313,6 +428,14 @@ def type_validator(name: str, value: Any, expected_type: Any) -> None:
_validate_simple_type(name, value, expected_type)
elif isinstance(expected_type, ForwardRef) or isinstance(expected_type, str):
return
elif origin is Required:
if value is _TYPED_DICT_DEFAULT_VALUE:
raise TypeError(f"Field '{name}' is required but missing.")
_validate_simple_type(name, value, args[0])
elif origin is NotRequired:
if value is _TYPED_DICT_DEFAULT_VALUE:
return
_validate_simple_type(name, value, args[0])
else:
raise TypeError(f"Unsupported type for field '{name}': {expected_type}")

Expand Down Expand Up @@ -449,6 +572,11 @@ def _is_validator(validator: Any) -> bool:
return True


def _is_required_or_notrequired(type_hint: Any) -> bool:
    """Tell whether a type hint is `Required`/`NotRequired`, either bare or subscripted.

    Args:
        type_hint (`Any`):
            The type hint to inspect.

    Returns:
        `bool`: `True` if the hint is one of the two markers themselves or a subscripted form such
        as `Required[int]` / `NotRequired[int]`, `False` otherwise.
    """
    markers = (Required, NotRequired)
    if type_hint in markers:
        # Bare marker, e.g. `Required`
        return True
    # Subscripted marker, e.g. `NotRequired[int]`
    return get_origin(type_hint) in markers


_BASIC_TYPE_VALIDATORS = {
Union: _validate_union,
Literal: _validate_literal,
Expand All @@ -461,6 +589,7 @@ def _is_validator(validator: Any) -> bool:

__all__ = [
"strict",
"validate_typed_dict",
"validated_field",
"Validator_T",
"StrictDataclassClassValidationError",
Expand Down
134 changes: 132 additions & 2 deletions tests/test_utils_strict_dataclass.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,27 @@
import inspect
import sys
from dataclasses import asdict, astuple, dataclass, is_dataclass
from typing import Any, Literal, Optional, Union, get_type_hints
from typing import Annotated, Any, Literal, Optional, TypedDict, Union, get_type_hints

import jedi
import pytest

from huggingface_hub.dataclasses import _is_validator, as_validated_field, strict, type_validator, validated_field

if sys.version_info >= (3, 11):
from typing import NotRequired, Required
else:
# Provide fallbacks or skip the entire module
NotRequired = None
Required = None
from huggingface_hub.dataclasses import (
_build_strict_cls_from_typed_dict,
_is_validator,
as_validated_field,
strict,
type_validator,
validate_typed_dict,
validated_field,
)
from huggingface_hub.errors import (
StrictDataclassClassValidationError,
StrictDataclassDefinitionError,
Expand Down Expand Up @@ -646,3 +662,117 @@ def validate(self):
@dataclass
class ConfigWithParent(ParentClass): # 'validate' already defined => should raise an error
foo: int = 0


class ConfigDict(TypedDict):
    # TypedDict schema used as the reference schema by the `validate_typed_dict` tests.

    # Plain type: the value must be a string.
    str_value: str
    # Annotated type: the value must be an int and is also passed to the `positive_int` validator
    # (`positive_int` is defined outside this excerpt).
    positive_int_value: Annotated[int, positive_int]
    # String forward reference: the tests pass an arbitrary string for this key.
    forward_ref_value: "ForwardDtype"
    # Optional type: the tests still expect the key itself to be present.
    optional_value: Optional[int]


@pytest.mark.parametrize(
    "data",
    [
        # Every key is present and holds a value accepted by its declared type
        {
            "str_value": "foo",
            "positive_int_value": 1,
            "forward_ref_value": "bar",
            "optional_value": 0,
        },
    ],
)
def test_typed_dict_valid_data(data: dict):
    """Data matching `ConfigDict` must be accepted without raising."""
    validate_typed_dict(ConfigDict, data)


# Fully valid payload for `ConfigDict`; each invalid case below alters exactly one aspect of it.
_VALID_CONFIG_DICT_DATA = {
    "str_value": "foo",
    "positive_int_value": 1,
    "forward_ref_value": "bar",
    "optional_value": 0,
}


@pytest.mark.parametrize(
    "data",
    [
        # Optional value cannot be omitted
        {key: value for key, value in _VALID_CONFIG_DICT_DATA.items() if key != "optional_value"},
        # Other fields neither
        {key: value for key, value in _VALID_CONFIG_DICT_DATA.items() if key != "str_value"},
        # Not a string
        {**_VALID_CONFIG_DICT_DATA, "str_value": 123},
        # Not an integer
        {**_VALID_CONFIG_DICT_DATA, "positive_int_value": "1"},
        # Annotated validator is used
        {**_VALID_CONFIG_DICT_DATA, "positive_int_value": -1},
    ],
)
def test_typed_dict_invalid_data(data: dict):
    """Missing keys, wrong types and values rejected by a validator must all raise."""
    with pytest.raises(StrictDataclassFieldValidationError):
        validate_typed_dict(ConfigDict, data)


def test_typed_dict_error_message():
    """The raised error must name the offending field and describe the type mismatch."""
    invalid_data = {
        "str_value": 123,  # int where a str is expected
        "positive_int_value": 1,
        "forward_ref_value": "bar",
        "optional_value": 0,
    }

    with pytest.raises(StrictDataclassFieldValidationError) as exc_info:
        validate_typed_dict(ConfigDict, invalid_data)

    message = str(exc_info.value)
    assert "Validation error for field 'str_value'" in message
    assert "Field 'str_value' expected str, got int (value: 123)" in message


def test_typed_dict_unknown_attribute():
    """A key that is not declared in the TypedDict must be rejected with a `TypeError`."""
    data_with_extra_key = {
        "str_value": "foo",
        "positive_int_value": 1,
        "forward_ref_value": "bar",
        "optional_value": 0,
        "another_value": 0,  # not part of ConfigDict
    }

    with pytest.raises(TypeError):
        validate_typed_dict(ConfigDict, data_with_extra_key)


def test_typed_dict_to_dataclass_is_cached():
    """Converting the same TypedDict twice must return the very same dataclass object."""
    first_build = _build_strict_cls_from_typed_dict(ConfigDict)
    second_build = _build_strict_cls_from_typed_dict(ConfigDict)

    # Identity (not equality): the dataclass is built only once and then served from the cache.
    assert first_build is second_build


@pytest.mark.skipif(sys.version_info < (3, 11), reason="Requires Python 3.11+")
class TestConfigDictNotRequired:
    """Tests for TypedDict fields explicitly marked as `Required` / `NotRequired`.

    The schema is built in `setup_method` rather than in `__init__`: pytest does not collect test
    classes that define an `__init__` constructor, so with `__init__` none of these tests would run.
    """

    def setup_method(self):
        # cannot be defined at class level because of Python<3.11
        # (`Required` / `NotRequired` are set to None in that case, and subscripting them at import
        # time would fail; this hook only runs when the class is not skipped)
        self.ConfigDictNotRequired = TypedDict(
            "ConfigDictNotRequired",
            {"required_value": Required[int], "not_required_value": NotRequired[int]},
            total=False,
        )

    @pytest.mark.parametrize(
        "data",
        [
            {"required_value": 1, "not_required_value": 2},
            {"required_value": 1},  # not required value is not validated
        ],
    )
    def test_typed_dict_not_required_valid_data(self, data: dict):
        """A `NotRequired` key may be omitted; when present with the right type it is accepted."""
        validate_typed_dict(self.ConfigDictNotRequired, data)

    @pytest.mark.parametrize(
        "data",
        [
            # Missing required value
            {"not_required_value": 2},
            # If exists, the value is validated
            {"required_value": 1, "not_required_value": "2"},
        ],
    )
    def test_typed_dict_not_required_invalid_data(self, data: dict):
        """A missing `Required` key or a wrongly typed `NotRequired` value must raise."""
        with pytest.raises(StrictDataclassFieldValidationError):
            validate_typed_dict(self.ConfigDictNotRequired, data)


def test_typed_dict_total_true():
    """With `total=True`, a declared key is mandatory: omitting it must raise."""
    schema = TypedDict("ConfigDictTotalTrue", {"value": int}, total=True)

    # Key present with the right type => accepted
    validate_typed_dict(schema, {"value": 1})

    # Key missing => rejected
    with pytest.raises(StrictDataclassFieldValidationError):
        validate_typed_dict(schema, {})


def test_typed_dict_total_false():
    """With `total=False`, a declared key may be omitted or provided with the right type."""
    schema = TypedDict("ConfigDictTotalFalse", {"value": int}, total=False)

    # Neither the empty dict nor the fully populated one may raise.
    for data in ({}, {"value": 1}):
        validate_typed_dict(schema, data)