Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion airflow-core/src/airflow/serialization/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

import airflow.serialization.serializers
from airflow.configuration import conf
from airflow.serialization.typing import is_pydantic_model
from airflow.stats import Stats
from airflow.utils.module_loading import import_string, iter_namespace, qualname

Expand All @@ -52,6 +53,7 @@
OLD_SOURCE = "__source"
OLD_DATA = "__var"
OLD_DICT = "dict"
PYDANTIC_MODEL_QUALNAME = "pydantic.main.BaseModel"

DEFAULT_VERSION = 0

Expand Down Expand Up @@ -145,6 +147,12 @@ def serialize(o: object, depth: int = 0) -> U | None:
qn = "builtins.tuple"
classname = qn

if is_pydantic_model(o):
# to match the generic Pydantic serializer and deserializer in _serializers and _deserializers
qn = PYDANTIC_MODEL_QUALNAME
# the actual Pydantic model class to encode
classname = qualname(o)

# if there is a builtin serializer available use that
if qn in _serializers:
data, serialized_classname, version, is_serialized = _serializers[qn].serialize(o)
Expand Down Expand Up @@ -256,7 +264,10 @@ def deserialize(o: T | None, full=True, type_hint: Any = None) -> object:

# registered deserializer
if classname in _deserializers:
return _deserializers[classname].deserialize(classname, version, deserialize(value))
return _deserializers[classname].deserialize(cls, version, deserialize(value))
if is_pydantic_model(cls):
if PYDANTIC_MODEL_QUALNAME in _deserializers:
return _deserializers[PYDANTIC_MODEL_QUALNAME].deserialize(cls, version, deserialize(value))

# class has deserialization function
if hasattr(cls, "deserialize"):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,13 @@ def serialize(o: object) -> tuple[U, str, int, bool]:
return float(o), name, __version__, True


def deserialize(classname: str, version: int, data: object) -> decimal.Decimal:
def deserialize(cls: type, version: int, data: object) -> decimal.Decimal:
from decimal import Decimal

if version > __version__:
raise TypeError(f"serialized {version} of {classname} > {__version__}")
raise TypeError(f"serialized {version} of {qualname(cls)} > {__version__}")

if classname != qualname(Decimal):
raise TypeError(f"{classname} != {qualname(Decimal)}")
if cls is not Decimal:
raise TypeError(f"do not know how to deserialize {qualname(cls)}")

return Decimal(str(data))
12 changes: 6 additions & 6 deletions airflow-core/src/airflow/serialization/serializers/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,20 +35,20 @@ def serialize(o: object) -> tuple[U, str, int, bool]:
return list(cast("list", o)), qualname(o), __version__, True


def deserialize(classname: str, version: int, data: list) -> tuple | set | frozenset:
def deserialize(cls: type, version: int, data: list) -> tuple | set | frozenset:
if version > __version__:
raise TypeError("serialized version is newer than class version")
raise TypeError(f"serialized version {version} is newer than class version {__version__}")

if classname == qualname(tuple):
if cls is tuple:
return tuple(data)

if classname == qualname(set):
if cls is set:
return set(data)

if classname == qualname(frozenset):
if cls is frozenset:
return frozenset(data)

raise TypeError(f"do not know how to deserialize {classname}")
raise TypeError(f"do not know how to deserialize {qualname(cls)}")


def stringify(classname: str, version: int, data: list) -> str:
Expand Down
12 changes: 6 additions & 6 deletions airflow-core/src/airflow/serialization/serializers/datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def serialize(o: object) -> tuple[U, str, int, bool]:
return "", "", 0, False


def deserialize(classname: str, version: int, data: dict | str) -> datetime.date | datetime.timedelta:
def deserialize(cls: type, version: int, data: dict | str) -> datetime.date | datetime.timedelta:
import datetime

from pendulum import DateTime
Expand All @@ -86,16 +86,16 @@ def deserialize(classname: str, version: int, data: dict | str) -> datetime.date
else None
)

if classname == qualname(datetime.datetime) and isinstance(data, dict):
if cls is datetime.datetime and isinstance(data, dict):
return datetime.datetime.fromtimestamp(float(data[TIMESTAMP]), tz=tz)

if classname == qualname(DateTime) and isinstance(data, dict):
if cls is DateTime and isinstance(data, dict):
return DateTime.fromtimestamp(float(data[TIMESTAMP]), tz=tz)

if classname == qualname(datetime.timedelta) and isinstance(data, (str, float)):
if cls is datetime.timedelta and isinstance(data, str | float):
return datetime.timedelta(seconds=float(data))

if classname == qualname(datetime.date) and isinstance(data, str):
if cls is datetime.date and isinstance(data, str):
return datetime.date.fromisoformat(data)

raise TypeError(f"unknown date/time format {classname}")
raise TypeError(f"unknown date/time format {qualname(cls)}")
Original file line number Diff line number Diff line change
Expand Up @@ -55,15 +55,15 @@ def serialize(o: object) -> tuple[U, str, int, bool]:
return data, qualname(o), __version__, True


def deserialize(classname: str, version: int, data: dict):
def deserialize(cls: type, version: int, data: dict):
from deltalake.table import DeltaTable

from airflow.models.crypto import get_fernet

if version > __version__:
raise TypeError("serialized version is newer than class version")

if classname == qualname(DeltaTable):
if cls is DeltaTable:
fernet = get_fernet()
properties = {}
for k, v in data["storage_options"].items():
Expand All @@ -76,4 +76,4 @@ def deserialize(classname: str, version: int, data: dict):

return DeltaTable(data["table_uri"], version=data["version"], storage_options=storage_options)

raise TypeError(f"do not know how to deserialize {classname}")
raise TypeError(f"do not know how to deserialize {qualname(cls)}")
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def serialize(o: object) -> tuple[U, str, int, bool]:
return data, qualname(o), __version__, True


def deserialize(classname: str, version: int, data: dict):
def deserialize(cls: type, version: int, data: dict):
from pyiceberg.catalog import load_catalog
from pyiceberg.table import Table

Expand All @@ -64,7 +64,7 @@ def deserialize(classname: str, version: int, data: dict):
if version > __version__:
raise TypeError("serialized version is newer than class version")

if classname == qualname(Table):
if cls is Table:
fernet = get_fernet()
properties = {}
for k, v in data["catalog_properties"].items():
Expand All @@ -73,4 +73,4 @@ def deserialize(classname: str, version: int, data: dict):
catalog = load_catalog(data["identifier"][0], **properties)
return catalog.load_table((data["identifier"][1], data["identifier"][2]))

raise TypeError(f"do not know how to deserialize {classname}")
raise TypeError(f"do not know how to deserialize {qualname(cls)}")
10 changes: 6 additions & 4 deletions airflow-core/src/airflow/serialization/serializers/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,13 @@ def serialize(o: object) -> tuple[U, str, int, bool]:
return "", "", 0, False


def deserialize(classname: str, version: int, data: str) -> Any:
def deserialize(cls: type, version: int, data: str) -> Any:
if version > __version__:
raise TypeError("serialized version is newer than class version")

if classname not in deserializers:
raise TypeError(f"unsupported {classname} found for numpy deserialization")
allowed_deserialize_classes = [import_string(classname) for classname in deserializers]

return import_string(classname)(data)
if cls not in allowed_deserialize_classes:
raise TypeError(f"unsupported {qualname(cls)} found for numpy deserialization")

return cls(data)
13 changes: 9 additions & 4 deletions airflow-core/src/airflow/serialization/serializers/pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,17 +53,22 @@ def serialize(o: object) -> tuple[U, str, int, bool]:
return buf.getvalue().hex().decode("utf-8"), qualname(o), __version__, True


def deserialize(classname: str, version: int, data: object) -> pd.DataFrame:
def deserialize(cls: type, version: int, data: object) -> pd.DataFrame:
if version > __version__:
raise TypeError(f"serialized {version} of {classname} > {__version__}")
raise TypeError(f"serialized {version} of {qualname(cls)} > {__version__}")

from pyarrow import parquet as pq
import pandas as pd

if cls is not pd.DataFrame:
raise TypeError(f"do not know how to deserialize {qualname(cls)}")

if not isinstance(data, str):
raise TypeError(f"serialized {classname} has wrong data type {type(data)}")
raise TypeError(f"serialized {qualname(cls)} has wrong data type {type(data)}")

from io import BytesIO

from pyarrow import parquet as pq

with BytesIO(bytes.fromhex(data)) as buf:
df = pq.read_table(buf).to_pandas()

Expand Down
75 changes: 75 additions & 0 deletions airflow-core/src/airflow/serialization/serializers/pydantic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from typing import TYPE_CHECKING

from airflow.serialization.typing import is_pydantic_model
from airflow.utils.module_loading import qualname

if TYPE_CHECKING:
from airflow.serialization.serde import U

serializers = [
"pydantic.main.BaseModel",
]
deserializers = serializers

__version__ = 1


def serialize(o: object) -> tuple[U, str, int, bool]:
    """
    Serialize a Pydantic model into the plain structure produced by ``model_dump()``.

    The returned tuple follows the serializer-module convention used by the
    sibling serializers:

    - the dumped model data,
    - the qualified name of ``o`` (the concrete model, not ``BaseModel``),
    - the serializer ``__version__``,
    - a flag telling the caller whether this serializer handled ``o``.

    :param o: Candidate object; anything that does not look like a Pydantic
        model is declined rather than rejected with an exception.
    :return: ``(data, classname, version, True)`` when handled, otherwise the
        "not handled" sentinel ``("", "", 0, False)``.
    """
    if is_pydantic_model(o):
        # model_dump() is called with its defaults, so no arguments are passed.
        dumped = o.model_dump()  # type: ignore
        return dumped, qualname(o), __version__, True

    # Not a Pydantic model: signal "not serialized" so the caller can try
    # another strategy.
    return "", "", 0, False


def deserialize(cls: type, version: int, data: dict):
    """
    Rebuild a Pydantic model instance from its dumped dictionary.

    The dictionary is expected to be what ``model_dump()`` produced for the
    model; reconstruction is delegated to ``cls.model_validate``, so the data
    is validated against the model while it is rebuilt.

    NOTE(review): restricting which classes may be deserialized (e.g. via
    ``allowed_deserialization_classes``) is presumably enforced by the caller
    before this function is reached -- it is not checked here.

    :param cls: The concrete model class to instantiate.
    :param version: Version recorded at serialization time; must not be newer
        than this module's ``__version__``.
    :param data: Dictionary of dumped field values.
    :return: An instance of ``cls``.
    :raises TypeError: If ``version`` is newer than supported, or if ``cls``
        does not look like a Pydantic model.
    """
    # Version is checked first so an unsupported payload is rejected before
    # the class itself is inspected.
    if version > __version__:
        raise TypeError(f"Serialized version {version} is newer than the supported version {__version__}")

    if is_pydantic_model(cls):
        # Validation-based reconstruction of the model.
        return cls.model_validate(data)  # type: ignore

    # Anything else has no deserializer in this module.
    raise TypeError(f"No deserializer found for {qualname(cls)}")
10 changes: 5 additions & 5 deletions airflow-core/src/airflow/serialization/serializers/timezone.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,18 +67,18 @@ def serialize(o: object) -> tuple[U, str, int, bool]:
return "", "", 0, False


def deserialize(classname: str, version: int, data: object) -> Any:
def deserialize(cls: type, version: int, data: object) -> Any:
from zoneinfo import ZoneInfo

from airflow.utils.timezone import parse_timezone

if not isinstance(data, (str, int)):
raise TypeError(f"{data} is not of type int or str but of {type(data)}")

if version > __version__:
raise TypeError(f"serialized {version} of {classname} > {__version__}")

if classname == "backports.zoneinfo.ZoneInfo" and isinstance(data, str):
from zoneinfo import ZoneInfo
raise TypeError(f"serialized {version} of {qualname(cls)} > {__version__}")

if cls is ZoneInfo and isinstance(data, str):
return ZoneInfo(data)

return parse_timezone(data)
Expand Down
32 changes: 32 additions & 0 deletions airflow-core/src/airflow/serialization/typing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from typing import Any


def is_pydantic_model(cls: Any) -> bool:
    """
    Tell whether ``cls`` looks like a Pydantic (V2) ``BaseModel``.

    Detection is duck-typed: the argument qualifies when it exposes both
    Pydantic marker attributes. Because only ``hasattr`` is used, the check
    gives the same answer for a model class and for an instance of it.
    Attribute probing is used instead of ``isinstance`` because the original
    author found it significantly faster; it also means pydantic is not
    imported here.

    :param cls: Any object or class to inspect.
    :return: ``True`` when both marker attributes are present, else ``False``.
    """
    # __pydantic_fields__: the field mapping present on Pydantic V2 models.
    # __pydantic_validator__: the internal validator set once the model is built.
    required_markers = ("__pydantic_fields__", "__pydantic_validator__")
    return all(hasattr(cls, marker) for marker in required_markers)
Loading