This repository was archived by the owner on May 17, 2024. It is now read-only.

Retrieve collations from the schema (and refactor the column info structures) #814

Merged: 3 commits, Dec 30, 2023
6 changes: 3 additions & 3 deletions data_diff/__main__.py
@@ -12,8 +12,8 @@
from rich.logging import RichHandler
import click

from data_diff import Database
from data_diff.schema import create_schema
from data_diff import Database, DbPath
from data_diff.schema import RawColumnInfo, create_schema
from data_diff.queries.api import current_timestamp

from data_diff.dbt import dbt_diff
@@ -72,7 +72,7 @@ def _remove_passwords_in_dict(d: dict) -> None:
d[k] = remove_password_from_url(v)


def _get_schema(pair):
def _get_schema(pair: Tuple[Database, DbPath]) -> Dict[str, RawColumnInfo]:
db, table_path = pair
return db.query_table_schema(table_path)

88 changes: 87 additions & 1 deletion data_diff/abcs/database_types.py
@@ -1,6 +1,6 @@
import decimal
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple, Type, TypeVar, Union
from typing import Collection, List, Optional, Tuple, Type, TypeVar, Union
from datetime import datetime

import attrs
@@ -15,6 +15,91 @@
N = TypeVar("N")


@attrs.frozen(kw_only=True, eq=False, order=False, unsafe_hash=True)
class Collation:
"""
A pre-parsed or pre-known record about db collation, per column.

The "greater" collation should be used as a target collation for textual PKs
    on both sides of the diff, by converting the "lesser" collation to self.

    Snowflake easily absorbs the performance losses, so its collation always
    compares as "lesser" than that of any non-Snowflake database, making
    Snowflake the side that does the converting. Other databases need to
    negotiate which side absorbs the performance impact.
"""

    # A boost for special databases that are known to absorb the performance damage well.
absorbs_damage: bool = False

    # Ordinal sorting by ASCII/UTF8 (True), or alphabetic as per locale/country/etc (False).
ordinal: Optional[bool] = None

# Lowercase first (aAbBcC or abcABC). Otherwise, uppercase first (AaBbCc or ABCabc).
lower_first: Optional[bool] = None

    # 2-letter lower-case language and upper-case country codes, e.g. en_US. Ignored for ordinals.
language: Optional[str] = None
country: Optional[str] = None

    # There are also space-, punctuation-, width-, kana-(in)sensitivity, and so on.
    # Ignore everything not related to cross-db alignment. Only case- & accent-sensitivity are common.
case_sensitive: Optional[bool] = None
accent_sensitive: Optional[bool] = None

# Purely informational, for debugging:
_source: Union[None, str, Collection[str]] = None

def __eq__(self, other: object) -> bool:
if not isinstance(other, Collation):
return NotImplemented
if self.ordinal and other.ordinal:
            # TODO: does it depend on language? What does Albanian_BIN mean in MS SQL?
return True
return (
self.language == other.language
and (self.country is None or other.country is None or self.country == other.country)
and self.case_sensitive == other.case_sensitive
and self.accent_sensitive == other.accent_sensitive
and self.lower_first == other.lower_first
)

def __ne__(self, other: object) -> bool:
if not isinstance(other, Collation):
return NotImplemented
return not self.__eq__(other)

def __gt__(self, other: object) -> bool:
if not isinstance(other, Collation):
return NotImplemented
if self == other:
return False
if self.absorbs_damage and not other.absorbs_damage:
return False
if other.absorbs_damage and not self.absorbs_damage:
            return True  # prefer this one: it cannot absorb the damage, but its counterpart can
if self.ordinal and not other.ordinal:
return True
if other.ordinal and not self.ordinal:
return False
# TODO: try to align the languages & countries?
return False

def __ge__(self, other: object) -> bool:
if not isinstance(other, Collation):
return NotImplemented
return self == other or self.__gt__(other)

def __lt__(self, other: object) -> bool:
if not isinstance(other, Collation):
return NotImplemented
return self != other and not self.__gt__(other)

def __le__(self, other: object) -> bool:
if not isinstance(other, Collation):
return NotImplemented
return self == other or not self.__gt__(other)


@attrs.define(frozen=True, kw_only=True)
class ColType:
# Arbitrary metadata added and fetched at runtime.
@@ -112,6 +197,7 @@ def python_type(self) -> type:
@attrs.define(frozen=True)
class StringType(ColType):
python_type = str
collation: Optional[Collation] = attrs.field(default=None, kw_only=True)


@attrs.define(frozen=True)
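A quick sketch of how this ordering is meant to be used: negotiating a target collation for a textual PK across the two sides of a diff. This assumes the package from this PR is importable; the field values below are illustrative, not taken from any real dialect.

```python
from data_diff.abcs.database_types import Collation

# One side (e.g. Snowflake) is marked as absorbing the performance damage well.
snowflake = Collation(absorbs_damage=True, ordinal=False, language="en")
postgres = Collation(absorbs_damage=False, ordinal=True)

# Per __gt__ above, the side that cannot absorb the damage compares as "greater",
# so it becomes the target collation and the absorbing side does the converting.
target = max(snowflake, postgres)
assert target is postgres
```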
62 changes: 35 additions & 27 deletions data_diff/databases/base.py
@@ -19,6 +19,7 @@

from data_diff.abcs.compiler import AbstractCompiler, Compilable
from data_diff.queries.extras import ApplyFuncAndNormalizeAsString, Checksum, NormalizeAsString
from data_diff.schema import RawColumnInfo
from data_diff.utils import ArithString, is_uuid, join_iter, safezip
from data_diff.queries.api import Expr, table, Select, SKIP, Explain, Code, this
from data_diff.queries.ast_classes import (
@@ -707,27 +708,18 @@ def type_repr(self, t) -> str:
datetime: "TIMESTAMP",
}[t]

def _parse_type_repr(self, type_repr: str) -> Optional[Type[ColType]]:
return self.TYPE_CLASSES.get(type_repr)

def parse_type(
self,
table_path: DbPath,
col_name: str,
type_repr: str,
datetime_precision: int = None,
numeric_precision: int = None,
numeric_scale: int = None,
) -> ColType:
def parse_type(self, table_path: DbPath, info: RawColumnInfo) -> ColType:
"Parse type info as returned by the database"

cls = self._parse_type_repr(type_repr)
cls = self.TYPE_CLASSES.get(info.data_type)
if cls is None:
return UnknownColType(type_repr)
return UnknownColType(info.data_type)

if issubclass(cls, TemporalType):
return cls(
precision=datetime_precision if datetime_precision is not None else DEFAULT_DATETIME_PRECISION,
precision=info.datetime_precision
if info.datetime_precision is not None
else DEFAULT_DATETIME_PRECISION,
rounds=self.ROUNDS_ON_PREC_LOSS,
)

@@ -738,22 +730,22 @@ def parse_type(
return cls()

elif issubclass(cls, Decimal):
if numeric_scale is None:
numeric_scale = 0 # Needed for Oracle.
return cls(precision=numeric_scale)
if info.numeric_scale is None:
return cls(precision=0) # Needed for Oracle.
return cls(precision=info.numeric_scale)

elif issubclass(cls, Float):
# assert numeric_scale is None
return cls(
precision=self._convert_db_precision_to_digits(
numeric_precision if numeric_precision is not None else DEFAULT_NUMERIC_PRECISION
info.numeric_precision if info.numeric_precision is not None else DEFAULT_NUMERIC_PRECISION
)
)

elif issubclass(cls, (JSON, Array, Struct, Text, Native_UUID)):
return cls()

raise TypeError(f"Parsing {type_repr} returned an unknown type '{cls}'.")
raise TypeError(f"Parsing {info.data_type} returned an unknown type {cls!r}.")

def _convert_db_precision_to_digits(self, p: int) -> int:
"""Convert from binary precision, used by floats, to decimal precision."""
@@ -1018,7 +1010,7 @@ def select_table_schema(self, path: DbPath) -> str:
f"WHERE table_name = '{name}' AND table_schema = '{schema}'"
)

def query_table_schema(self, path: DbPath) -> Dict[str, tuple]:
def query_table_schema(self, path: DbPath) -> Dict[str, RawColumnInfo]:
"""Query the table for its schema for table in 'path', and return {column: tuple}
where the tuple is (table_name, col_name, type_repr, datetime_precision?, numeric_precision?, numeric_scale?)

@@ -1029,7 +1021,7 @@ def query_table_schema(self, path: DbPath) -> Dict[str, tuple]:
if not rows:
raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns")

d = {r[0]: r for r in rows}
d = {
r[0]: RawColumnInfo(
column_name=r[0],
data_type=r[1],
datetime_precision=r[2],
numeric_precision=r[3],
numeric_scale=r[4],
collation_name=r[5] if len(r) > 5 else None,
)
for r in rows
}
assert len(d) == len(rows)
return d

@@ -1051,7 +1053,11 @@ def query_table_unique_columns(self, path: DbPath) -> List[str]:
return list(res)

def _process_table_schema(
self, path: DbPath, raw_schema: Dict[str, tuple], filter_columns: Sequence[str] = None, where: str = None
self,
path: DbPath,
raw_schema: Dict[str, RawColumnInfo],
filter_columns: Sequence[str] = None,
where: str = None,
):
"""Process the result of query_table_schema().

@@ -1067,7 +1073,7 @@ def _process_table_schema(
accept = {i.lower() for i in filter_columns}
filtered_schema = {name: row for name, row in raw_schema.items() if name.lower() in accept}

col_dict = {row[0]: self.dialect.parse_type(path, *row) for _name, row in filtered_schema.items()}
col_dict = {info.column_name: self.dialect.parse_type(path, info) for info in filtered_schema.values()}

self._refine_coltypes(path, col_dict, where)

@@ -1076,15 +1082,15 @@

def _refine_coltypes(
self, table_path: DbPath, col_dict: Dict[str, ColType], where: Optional[str] = None, sample_size=64
):
) -> Dict[str, ColType]:
"""Refine the types in the column dict, by querying the database for a sample of their values

'where' restricts the rows to be sampled.
"""

text_columns = [k for k, v in col_dict.items() if isinstance(v, Text)]
if not text_columns:
return
return col_dict

fields = [Code(self.dialect.normalize_uuid(self.dialect.quote(c), String_UUID())) for c in text_columns]

@@ -1116,7 +1122,9 @@ def _refine_coltypes(
)
else:
assert col_name in col_dict
col_dict[col_name] = String_VaryingAlphanum()
col_dict[col_name] = String_VaryingAlphanum(collation=col_dict[col_name].collation)

return col_dict

def _normalize_table_path(self, path: DbPath) -> DbPath:
if len(path) == 1:
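For illustration, the refactored schema flow end to end. A sketch only: the DSN, table path, and column name are placeholders, and RawColumnInfo's fields are inferred from the constructor call in query_table_schema() above (the class itself lives in data_diff/schema.py, outside this diff).

```python
from data_diff import connect

db = connect("postgresql://user:password@localhost:5432/mydb")  # placeholder DSN
path = ("public", "ratings")  # placeholder table

# query_table_schema() now returns {column_name: RawColumnInfo} instead of raw tuples...
raw_schema = db.query_table_schema(path)
info = raw_schema["rating"]  # hypothetical column

# ...and parse_type() takes the whole record instead of six loose arguments.
col_type = db.dialect.parse_type(path, info)
print(info.data_type, info.collation_name, col_type)
```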
19 changes: 7 additions & 12 deletions data_diff/databases/bigquery.py
@@ -33,6 +33,7 @@
MD5_HEXDIGITS,
)
from data_diff.databases.base import TIMESTAMP_PRECISION_POS, ThreadLocalInterpreter
from data_diff.schema import RawColumnInfo


@import_helper(text="Please install BigQuery and configure your google-cloud access.")
@@ -91,27 +92,21 @@ def type_repr(self, t) -> str:
except KeyError:
return super().type_repr(t)

def parse_type(
self,
table_path: DbPath,
col_name: str,
type_repr: str,
*args: Any, # pass-through args
**kwargs: Any, # pass-through args
) -> ColType:
col_type = super().parse_type(table_path, col_name, type_repr, *args, **kwargs)
def parse_type(self, table_path: DbPath, info: RawColumnInfo) -> ColType:
col_type = super().parse_type(table_path, info)
if isinstance(col_type, UnknownColType):
m = self.TYPE_ARRAY_RE.fullmatch(type_repr)
m = self.TYPE_ARRAY_RE.fullmatch(info.data_type)
if m:
item_type = self.parse_type(table_path, col_name, m.group(1), *args, **kwargs)
item_info = attrs.evolve(info, data_type=m.group(1))
item_type = self.parse_type(table_path, item_info)
col_type = Array(item_type=item_type)

# We currently ignore structs' structure, but later can parse it too. Examples:
# - STRUCT<INT64, STRING(10)> (unnamed)
# - STRUCT<foo INT64, bar STRING(10)> (named)
# - STRUCT<foo INT64, bar ARRAY<INT64>> (with complex fields)
# - STRUCT<foo INT64, bar STRUCT<a INT64, b INT64>> (nested)
m = self.TYPE_STRUCT_RE.fullmatch(type_repr)
m = self.TYPE_STRUCT_RE.fullmatch(info.data_type)
if m:
col_type = Struct()

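The attrs.evolve() call above is what replaces the old *args/**kwargs pass-through: a child record is derived for the array's element type and parsed recursively. A self-contained sketch; the regex and the stand-in record are assumptions for illustration (the real TYPE_ARRAY_RE and RawColumnInfo live in the BigQuery dialect and data_diff/schema.py).

```python
import re

import attrs


@attrs.define(frozen=True)
class FakeInfo:  # stand-in for data_diff.schema.RawColumnInfo
    column_name: str
    data_type: str


TYPE_ARRAY_RE = re.compile(r"ARRAY<(.+)>")  # assumed shape of the real pattern

info = FakeInfo(column_name="tags", data_type="ARRAY<STRING(10)>")
m = TYPE_ARRAY_RE.fullmatch(info.data_type)
if m:
    # Same column, retyped as its element; recursion then parses "STRING(10)".
    item_info = attrs.evolve(info, data_type=m.group(1))
    assert item_info.data_type == "STRING(10)"
```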
22 changes: 12 additions & 10 deletions data_diff/databases/clickhouse.py
@@ -14,6 +14,7 @@
)
from data_diff.abcs.database_types import (
ColType,
DbPath,
Decimal,
Float,
Integer,
Expand All @@ -24,6 +25,7 @@
Timestamp,
Boolean,
)
from data_diff.schema import RawColumnInfo

# https://clickhouse.com/docs/en/operations/server-configuration-parameters/settings/#default-database
DEFAULT_DATABASE = "default"
@@ -75,19 +77,19 @@ def _convert_db_precision_to_digits(self, p: int) -> int:
# because it does not help for float with a big integer part.
return super()._convert_db_precision_to_digits(p) - 2

def _parse_type_repr(self, type_repr: str) -> Optional[Type[ColType]]:
def parse_type(self, table_path: DbPath, info: RawColumnInfo) -> ColType:
nullable_prefix = "Nullable("
if type_repr.startswith(nullable_prefix):
type_repr = type_repr[len(nullable_prefix) :].rstrip(")")
if info.data_type.startswith(nullable_prefix):
info = attrs.evolve(info, data_type=info.data_type[len(nullable_prefix) :].rstrip(")"))

if type_repr.startswith("Decimal"):
type_repr = "Decimal"
elif type_repr.startswith("FixedString"):
type_repr = "FixedString"
elif type_repr.startswith("DateTime64"):
type_repr = "DateTime64"
if info.data_type.startswith("Decimal"):
info = attrs.evolve(info, data_type="Decimal")
elif info.data_type.startswith("FixedString"):
info = attrs.evolve(info, data_type="FixedString")
elif info.data_type.startswith("DateTime64"):
info = attrs.evolve(info, data_type="DateTime64")

return self.TYPE_CLASSES.get(type_repr)
return super().parse_type(table_path, info)

# def timestamp_value(self, t: DbTime) -> str:
# # return f"'{t}'"
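A minimal standalone sketch of the normalization the ClickHouse override performs before delegating to the base parse_type(); it mirrors the attrs.evolve() calls above and is for illustration only.

```python
def normalize_clickhouse_type(data_type: str) -> str:
    # Unwrap Nullable(...) and collapse parametrized types to their base name.
    nullable_prefix = "Nullable("
    if data_type.startswith(nullable_prefix):
        data_type = data_type[len(nullable_prefix):].rstrip(")")
    for prefix in ("Decimal", "FixedString", "DateTime64"):
        if data_type.startswith(prefix):
            return prefix
    return data_type


assert normalize_clickhouse_type("Nullable(Decimal(38, 10))") == "Decimal"
assert normalize_clickhouse_type("FixedString(16)") == "FixedString"
assert normalize_clickhouse_type("DateTime64(3)") == "DateTime64"
```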