Skip to content

feat: Allow semantic comparison of schemas #55

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Jun 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions dataframely/_rule.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Copyright (c) QuantCo 2025-2025
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations

from collections import defaultdict
from collections.abc import Callable

Expand All @@ -15,6 +17,17 @@ class Rule:
def __init__(self, expr: pl.Expr) -> None:
self.expr = expr

def matches(self, other: Rule) -> bool:
    """Check whether this rule semantically matches another rule.

    Two rules match only if they are of the same concrete rule type and
    their expressions are structurally equal.

    Args:
        other: The rule to compare with.

    Returns:
        Whether the rules are semantically equal.
    """
    # Require the same concrete type so that the comparison is symmetric:
    # without this, a plain rule would "match" a more specific rule (which
    # carries extra state) while the reverse comparison reports a mismatch.
    if type(other) is not type(self):
        return False
    return self.expr.meta.eq(other.expr)


class GroupRule(Rule):
"""Rule that is evaluated on a group of columns."""
Expand All @@ -23,6 +36,11 @@ def __init__(self, expr: pl.Expr, group_columns: list[str]) -> None:
super().__init__(expr)
self.group_columns = group_columns

def matches(self, other: Rule) -> bool:
    """Check whether this group rule semantically matches another rule.

    Args:
        other: The rule to compare with.

    Returns:
        ``True`` only if ``other`` is a group rule as well, the base rule
        comparison succeeds and both rules use the same group columns
        (compared as lists, i.e. order matters).
    """
    if isinstance(other, GroupRule):
        expressions_match = super().matches(other)
        return expressions_match and self.group_columns == other.group_columns
    return False


def rule(*, group_by: list[str] | None = None) -> Callable[[ValidationFunction], Rule]:
"""Mark a function as a rule to evaluate during validation.
Expand Down
69 changes: 62 additions & 7 deletions dataframely/columns/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@

from __future__ import annotations

import inspect
from abc import ABC, abstractmethod
from collections import Counter
from collections.abc import Callable
from typing import Any
from typing import Any, TypeAlias

import polars as pl

Expand All @@ -15,6 +16,12 @@
from dataframely._polars import PolarsDataType
from dataframely.random import Generator

# Type of the `check` argument accepted by column constructors: a single callable
# mapping a column expression to an expression, a list of such callables, or a
# dictionary of such callables with string keys.
Check: TypeAlias = (
Callable[[pl.Expr], pl.Expr]
| list[Callable[[pl.Expr], pl.Expr]]
| dict[str, Callable[[pl.Expr], pl.Expr]]
)

# ------------------------------------------------------------------------------------ #
# COLUMNS #
# ------------------------------------------------------------------------------------ #
Expand All @@ -32,12 +39,7 @@ def __init__(
*,
nullable: bool | None = None,
primary_key: bool = False,
check: (
Callable[[pl.Expr], pl.Expr]
| list[Callable[[pl.Expr], pl.Expr]]
| dict[str, Callable[[pl.Expr], pl.Expr]]
| None
) = None,
check: Check | None = None,
alias: str | None = None,
metadata: dict[str, Any] | None = None,
):
Expand Down Expand Up @@ -257,7 +259,60 @@ def _null_probability(self) -> float:
"""Private utility for the null probability used during sampling."""
return 0.1 if self.nullable else 0

# ----------------------------------- EQUALITY ----------------------------------- #

def matches(self, other: Column, expr: pl.Expr) -> bool:
    """Check whether this column semantically matches another column.

    The comparison walks over every parameter of this class' ``__init__`` and
    compares the equally named attributes of both columns.

    Args:
        other: The column to compare with.
        expr: An expression referencing the column. It is passed to custom
            checks so that the expressions they produce can be compared.

    Returns:
        Whether the columns are semantically equal.
    """
    if not isinstance(other, self.__class__):
        return False

    # NOTE: The `alias` is deliberately excluded: this comparison only covers the
    #  type and its constraints. Names are checked in :meth:`Schema.matches`.
    excluded = ("self", "alias")
    init_signature = inspect.signature(self.__class__.__init__)
    for param in init_signature.parameters:
        if param in excluded:
            continue
        own_value = getattr(self, param)
        other_value = getattr(other, param)
        if not self._attributes_match(own_value, other_value, param, expr):
            return False
    return True

def _attributes_match(
self, lhs: Any, rhs: Any, name: str, column_expr: pl.Expr
) -> bool:
if name == "check":
return _compare_checks(lhs, rhs, column_expr)
return lhs == rhs

# -------------------------------- DUNDER METHODS -------------------------------- #

def __str__(self) -> str:
return self.__class__.__name__.lower()


def _compare_checks(lhs: Check | None, rhs: Check | None, expr: pl.Expr) -> bool:
match (lhs, rhs):
case (None, None):
return True
case (list(), list()):
return len(lhs) == len(rhs) and all(
left(expr).meta.eq(right(expr)) for left, right in zip(lhs, rhs)
)
case (dict(), dict()):
return lhs.keys() == rhs.keys() and all(
lhs[key](expr).meta.eq(rhs[key](expr)) for key in lhs.keys()
)
case _ if callable(lhs) and callable(rhs):
return lhs(expr).meta.eq(rhs(expr))
case _:
return False
11 changes: 2 additions & 9 deletions dataframely/columns/any.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,13 @@

from __future__ import annotations

from collections.abc import Callable

import polars as pl

from dataframely._compat import pa, sa, sa_mssql, sa_TypeEngine
from dataframely._polars import PolarsDataType
from dataframely.random import Generator

from ._base import Column
from ._base import Check, Column


class Any(Column):
Expand All @@ -25,12 +23,7 @@ class Any(Column):
def __init__(
self,
*,
check: (
Callable[[pl.Expr], pl.Expr]
| list[Callable[[pl.Expr], pl.Expr]]
| dict[str, Callable[[pl.Expr], pl.Expr]]
| None
) = None,
check: Check | None = None,
alias: str | None = None,
metadata: dict[str, Any] | None = None,
):
Expand Down
20 changes: 11 additions & 9 deletions dataframely/columns/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@
from __future__ import annotations

import math
from collections.abc import Callable, Sequence
from typing import Any, Literal
from collections.abc import Sequence
from typing import Any, Literal, cast

import polars as pl

from dataframely._compat import pa, sa, sa_TypeEngine
from dataframely.random import Generator

from ._base import Column
from ._base import Check, Column
from .struct import Struct


Expand All @@ -28,12 +28,7 @@ def __init__(
# polars doesn't yet support grouping by arrays,
# see https://github.com/pola-rs/polars/issues/22574
primary_key: Literal[False] = False,
check: (
Callable[[pl.Expr], pl.Expr]
| list[Callable[[pl.Expr], pl.Expr]]
| dict[str, Callable[[pl.Expr], pl.Expr]]
| None
) = None,
check: Check | None = None,
alias: str | None = None,
metadata: dict[str, Any] | None = None,
):
Expand Down Expand Up @@ -117,3 +112,10 @@ def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series:
all_elements.reshape((n, *self.shape)),
null_probability=self._null_probability,
)

def _attributes_match(
    self, lhs: Any, rhs: Any, name: str, column_expr: pl.Expr
) -> bool:
    """Compare a single constructor attribute, recursing into the inner column.

    Args:
        lhs: The attribute value of this column.
        rhs: The attribute value of the other column.
        name: The name of the attribute being compared.
        column_expr: Expression referencing this column.

    Returns:
        Whether the two attribute values are considered equal.
    """
    if name != "inner":
        return super()._attributes_match(lhs, rhs, name, column_expr)

    # The inner column is compared via its own `matches`, with `pl.element()`
    # standing in for the column expression.
    inner_lhs = cast(Column, lhs)
    inner_rhs = cast(Column, rhs)
    return inner_lhs.matches(inner_rhs, pl.element())
43 changes: 17 additions & 26 deletions dataframely/columns/datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from __future__ import annotations

import datetime as dt
from collections.abc import Callable
from typing import Any, cast

import polars as pl
Expand All @@ -20,7 +19,7 @@
)
from dataframely.random import Generator

from ._base import Column
from ._base import Check, Column
from ._mixins import OrdinalMixin
from ._utils import first_non_null, map_optional

Expand All @@ -40,12 +39,7 @@ def __init__(
max: dt.date | None = None,
max_exclusive: dt.date | None = None,
resolution: str | None = None,
check: (
Callable[[pl.Expr], pl.Expr]
| list[Callable[[pl.Expr], pl.Expr]]
| dict[str, Callable[[pl.Expr], pl.Expr]]
| None
) = None,
check: Check | None = None,
alias: str | None = None,
metadata: dict[str, Any] | None = None,
):
Expand Down Expand Up @@ -167,12 +161,7 @@ def __init__(
max: dt.time | None = None,
max_exclusive: dt.time | None = None,
resolution: str | None = None,
check: (
Callable[[pl.Expr], pl.Expr]
| list[Callable[[pl.Expr], pl.Expr]]
| dict[str, Callable[[pl.Expr], pl.Expr]]
| None
) = None,
check: Check | None = None,
alias: str | None = None,
metadata: dict[str, Any] | None = None,
):
Expand Down Expand Up @@ -302,12 +291,7 @@ def __init__(
resolution: str | None = None,
time_zone: str | dt.tzinfo | None = None,
time_unit: TimeUnit = "us",
check: (
Callable[[pl.Expr], pl.Expr]
| list[Callable[[pl.Expr], pl.Expr]]
| dict[str, Callable[[pl.Expr], pl.Expr]]
| None
) = None,
check: Check | None = None,
alias: str | None = None,
metadata: dict[str, Any] | None = None,
):
Expand Down Expand Up @@ -425,6 +409,18 @@ def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series:
null_probability=self._null_probability,
)

def _attributes_match(
self, lhs: Any, rhs: Any, name: str, column_expr: pl.Expr
) -> bool:
if (
name == "time_zone"
and isinstance(lhs, dt.tzinfo)
and isinstance(rhs, dt.tzinfo)
):
now = dt.datetime.now()
return lhs.utcoffset(now) == rhs.utcoffset(now)
return super()._attributes_match(lhs, rhs, name, column_expr)


class Duration(OrdinalMixin[dt.timedelta], Column):
"""A column of durations."""
Expand All @@ -439,12 +435,7 @@ def __init__(
max: dt.timedelta | None = None,
max_exclusive: dt.timedelta | None = None,
resolution: str | None = None,
check: (
Callable[[pl.Expr], pl.Expr]
| list[Callable[[pl.Expr], pl.Expr]]
| dict[str, Callable[[pl.Expr], pl.Expr]]
| None
) = None,
check: Check | None = None,
alias: str | None = None,
metadata: dict[str, Any] | None = None,
):
Expand Down
10 changes: 2 additions & 8 deletions dataframely/columns/decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import decimal
import math
from collections.abc import Callable
from typing import Any

import polars as pl
Expand All @@ -14,7 +13,7 @@
from dataframely._polars import PolarsDataType
from dataframely.random import Generator

from ._base import Column
from ._base import Check, Column
from ._mixins import OrdinalMixin
from ._utils import first_non_null, map_optional

Expand All @@ -33,12 +32,7 @@ def __init__(
min_exclusive: decimal.Decimal | None = None,
max: decimal.Decimal | None = None,
max_exclusive: decimal.Decimal | None = None,
check: (
Callable[[pl.Expr], pl.Expr]
| list[Callable[[pl.Expr], pl.Expr]]
| dict[str, Callable[[pl.Expr], pl.Expr]]
| None
) = None,
check: Check | None = None,
alias: str | None = None,
metadata: dict[str, Any] | None = None,
):
Expand Down
11 changes: 3 additions & 8 deletions dataframely/columns/enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from __future__ import annotations

from collections.abc import Callable, Sequence
from collections.abc import Sequence
from typing import Any

import polars as pl
Expand All @@ -12,7 +12,7 @@
from dataframely._polars import PolarsDataType
from dataframely.random import Generator

from ._base import Column
from ._base import Check, Column


class Enum(Column):
Expand All @@ -24,12 +24,7 @@ def __init__(
*,
nullable: bool | None = None,
primary_key: bool = False,
check: (
Callable[[pl.Expr], pl.Expr]
| list[Callable[[pl.Expr], pl.Expr]]
| dict[str, Callable[[pl.Expr], pl.Expr]]
| None
) = None,
check: Check | None = None,
alias: str | None = None,
metadata: dict[str, Any] | None = None,
):
Expand Down
Loading
Loading