Skip to content

Commit

Permalink
feat: Use pydantic for validation when reading (#18)
Browse files Browse the repository at this point in the history
* chore: add pydantic

* feat: use pydantic for validation when reading

* feat: support frozen dataclasses

* test: test class
  • Loading branch information
msto authored Oct 21, 2024
1 parent 377bcc7 commit 9e0ad05
Show file tree
Hide file tree
Showing 4 changed files with 305 additions and 1,176 deletions.
38 changes: 32 additions & 6 deletions dataclass_io/_lib/dataclass_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
from typing import ClassVar
from typing import Protocol

from pydantic.dataclasses import dataclass as pydantic_dataclass
from pydantic.dataclasses import is_pydantic_dataclass


class DataclassInstance(Protocol):
"""
Expand Down Expand Up @@ -42,11 +45,34 @@ def row_to_dataclass(
dataclass_type: The dataclass to which the row will be casted.
"""

coerced_values: dict[str, Any] = {}
data: DataclassInstance

# TODO support classes which inherit from `pydantic.BaseModel`
if is_pydantic_dataclass(dataclass_type):
# If we received a pydantic dataclass, we can simply use its validation
data = dataclass_type(**row)
else:
# If we received a stdlib dataclass, we use pydantic's dataclass decorator to create a
# version of the dataclass with validation. We instantiate from this version to take
# advantage of pydantic's validation, but then unpack the validated data in order to return
# an instance of the user-specified dataclass.

params = dataclass_type.__dataclass_params__ # type:ignore[attr-defined]

pydantic_cls = pydantic_dataclass(
_cls=dataclass_type,
repr=params.repr,
eq=params.eq,
order=params.order,
unsafe_hash=params.unsafe_hash,
frozen=params.frozen,
)

validated_data = pydantic_cls(**row)
unpacked_data = {
field.name: getattr(validated_data, field.name) for field in fields(dataclass_type)
}

# Coerce each value in the row to the type of the corresponding field
for field in fields(dataclass_type):
value = row[field.name]
coerced_values[field.name] = field.type(value)
data = dataclass_type(**unpacked_data)

return dataclass_type(**coerced_values)
return data
Loading

0 comments on commit 9e0ad05

Please sign in to comment.