Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Do not detect MD5s as UUIDs, and preserve UUID casing for UUID PKs #813

Merged
merged 5 commits into from
Dec 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions data_diff/abcs/database_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,8 @@ def python_type(self) -> type:
"Return the equivalent Python type of the key"

def make_value(self, value):
if isinstance(value, self.python_type):
return value
return self.python_type(value)


Expand Down Expand Up @@ -217,7 +219,14 @@ class Native_UUID(ColType_UUID):

@attrs.define(frozen=True)
class String_UUID(ColType_UUID, StringType):
pass
# Case is important for UUIDs stored as regular string, not native UUIDs stored as numbers.
# We slice them internally as numbers, but render them back to SQL as lower/upper case.
# None means we do not know for sure, behave as with False, but it might be unreliable.
lowercase: Optional[bool] = None
uppercase: Optional[bool] = None

def make_value(self, v: str) -> ArithUUID:
return self.python_type(v, lowercase=self.lowercase, uppercase=self.uppercase)


@attrs.define(frozen=True)
Expand All @@ -230,9 +239,6 @@ def test_value(value: str) -> bool:
except ValueError:
return False

def make_value(self, value):
return self.python_type(value)


@attrs.define(frozen=True)
class String_VaryingAlphanum(String_Alphanum):
Expand All @@ -244,6 +250,8 @@ class String_FixedAlphanum(String_Alphanum):
length: int

def make_value(self, value):
if isinstance(value, self.python_type):
return value
if len(value) != self.length:
raise ValueError(f"Expected alphanumeric value of length {self.length}, but got '{value}'.")
return self.python_type(value, max_len=self.length)
Expand Down
14 changes: 11 additions & 3 deletions data_diff/databases/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from data_diff.abcs.compiler import AbstractCompiler, Compilable
from data_diff.queries.extras import ApplyFuncAndNormalizeAsString, Checksum, NormalizeAsString
from data_diff.schema import RawColumnInfo
from data_diff.utils import ArithString, is_uuid, join_iter, safezip
from data_diff.utils import ArithString, ArithUUID, is_uuid, join_iter, safezip
from data_diff.queries.api import Expr, table, Select, SKIP, Explain, Code, this
from data_diff.queries.ast_classes import (
Alias,
Expand Down Expand Up @@ -248,6 +248,9 @@ def _compile(self, compiler: Compiler, elem) -> str:
return self.timestamp_value(elem)
elif isinstance(elem, bytes):
return f"b'{elem.decode()}'"
elif isinstance(elem, ArithUUID):
s = f"'{elem.uuid}'"
return s.upper() if elem.uppercase else s.lower() if elem.lowercase else s
elif isinstance(elem, ArithString):
return f"'{elem}'"
assert False, elem
Expand Down Expand Up @@ -681,8 +684,10 @@ def _constant_value(self, v):
return f"'{v}'"
elif isinstance(v, datetime):
return self.timestamp_value(v)
elif isinstance(v, UUID):
elif isinstance(v, UUID): # probably unused anymore in favour of ArithUUID
return f"'{v}'"
elif isinstance(v, ArithUUID):
return f"'{v.uuid}'"
elif isinstance(v, decimal.Decimal):
return str(v)
elif isinstance(v, bytearray):
Expand Down Expand Up @@ -1110,7 +1115,10 @@ def _refine_coltypes(
)
else:
assert col_name in col_dict
col_dict[col_name] = String_UUID()
col_dict[col_name] = String_UUID(
lowercase=all(s == s.lower() for s in uuid_samples),
uppercase=all(s == s.upper() for s in uuid_samples),
)
continue

if self.SUPPORTS_ALPHANUMS: # Anything but MySQL (so far)
Expand Down
11 changes: 7 additions & 4 deletions data_diff/diff_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,8 @@ def _bisect_and_diff_tables(self, table1: TableSegment, table2: TableSegment, in
# Start with the first completed value, so we don't waste time waiting
min_key1, max_key1 = self._parse_key_range_result(key_types1, next(key_ranges))

btable1, btable2 = [t.new_key_bounds(min_key=min_key1, max_key=max_key1) for t in (table1, table2)]
btable1 = table1.new_key_bounds(min_key=min_key1, max_key=max_key1, key_types=key_types1)
btable2 = table2.new_key_bounds(min_key=min_key1, max_key=max_key1, key_types=key_types2)

logger.info(
f"Diffing segments at key-range: {btable1.min_key}..{btable2.max_key}. "
Expand All @@ -324,16 +325,18 @@ def _bisect_and_diff_tables(self, table1: TableSegment, table2: TableSegment, in
# └──┴──────┴──┘
# Overall, the max number of new regions in this 2nd pass is 3^|k| - 1

min_key2, max_key2 = self._parse_key_range_result(key_types1, next(key_ranges))
# Note: python types can be the same, but the rendering parameters (e.g. casing) can differ.
min_key2, max_key2 = self._parse_key_range_result(key_types2, next(key_ranges))

points = [list(sorted(p)) for p in safezip(min_key1, min_key2, max_key1, max_key2)]
box_mesh = create_mesh_from_points(*points)

new_regions = [(p1, p2) for p1, p2 in box_mesh if p1 < p2 and not (p1 >= min_key1 and p2 <= max_key1)]

for p1, p2 in new_regions:
extra_tables = [t.new_key_bounds(min_key=p1, max_key=p2) for t in (table1, table2)]
ti.submit(self._bisect_and_diff_segments, ti, *extra_tables, info_tree, priority=999)
extra_table1 = table1.new_key_bounds(min_key=p1, max_key=p2, key_types=key_types1)
extra_table2 = table2.new_key_bounds(min_key=p1, max_key=p2, key_types=key_types2)
ti.submit(self._bisect_and_diff_segments, ti, extra_table1, extra_table2, info_tree, priority=999)

return ti

Expand Down
13 changes: 10 additions & 3 deletions data_diff/table_segment.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import time
from typing import Container, Dict, List, Optional, Tuple
from typing import Container, Dict, List, Optional, Sequence, Tuple
import logging
from itertools import product

Expand All @@ -9,7 +9,7 @@
from data_diff.utils import safezip, Vector
from data_diff.utils import ArithString, split_space
from data_diff.databases.base import Database
from data_diff.abcs.database_types import DbPath, DbKey, DbTime
from data_diff.abcs.database_types import DbPath, DbKey, DbTime, IKey
from data_diff.schema import RawColumnInfo, Schema, create_schema
from data_diff.queries.extras import Checksum
from data_diff.queries.api import Count, SKIP, table, this, Expr, min_, max_, Code
Expand Down Expand Up @@ -205,7 +205,7 @@ def new(self, **kwargs) -> Self:
"""Creates a copy of the instance using 'replace()'"""
return attrs.evolve(self, **kwargs)

def new_key_bounds(self, min_key: Vector, max_key: Vector) -> Self:
def new_key_bounds(self, min_key: Vector, max_key: Vector, *, key_types: Optional[Sequence[IKey]] = None) -> Self:
if self.min_key is not None:
assert self.min_key <= min_key, (self.min_key, min_key)
assert self.min_key < max_key
Expand All @@ -214,6 +214,13 @@ def new_key_bounds(self, min_key: Vector, max_key: Vector) -> Self:
assert min_key < self.max_key
assert max_key <= self.max_key

# If asked, enforce the PKs to proper types, mainly to meta-params of the relevant side,
# so that we do not leak e.g. casing of UUIDs from side A to side B and vice versa.
# If not asked, keep the meta-params of the keys as is (assume them already casted).
if key_types is not None:
min_key = Vector(type.make_value(val) for type, val in safezip(key_types, min_key))
max_key = Vector(type.make_value(val) for type, val in safezip(key_types, max_key))

return attrs.evolve(self, min_key=min_key, max_key=max_key)

@property
Expand Down
77 changes: 68 additions & 9 deletions data_diff/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,14 @@ def safezip(*args):
return zip(*args)


def is_uuid(u):
UUID_PATTERN = re.compile(r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}", re.I)


def is_uuid(u: str) -> bool:
# E.g., hashlib.md5(b'hello') is a 32-letter hex number, but not an UUID.
# It would fail UUID-like comparison (< & >) because of casing and dashes.
if not UUID_PATTERN.fullmatch(u):
return False
try:
UUID(u)
except ValueError:
Expand Down Expand Up @@ -128,23 +135,75 @@ def range(self, other: "ArithString", count: int) -> List[Self]:
return [self.new(int=i) for i in checkpoints]


# @attrs.define # not as long as it inherits from UUID
class ArithUUID(UUID, ArithString):
def _any_to_uuid(v: Union[str, int, UUID, "ArithUUID"]) -> UUID:
if isinstance(v, ArithUUID):
return v.uuid
elif isinstance(v, UUID):
return v
elif isinstance(v, str):
return UUID(v)
elif isinstance(v, int):
return UUID(int=v)
else:
raise ValueError(f"Cannot convert a value to UUID: {v!r}")


@attrs.define(frozen=True, eq=False, order=False)
class ArithUUID(ArithString):
"A UUID that supports basic arithmetic (add, sub)"

uuid: UUID = attrs.field(converter=_any_to_uuid)
lowercase: Optional[bool] = None
uppercase: Optional[bool] = None

def range(self, other: "ArithUUID", count: int) -> List[Self]:
assert isinstance(other, ArithUUID)
checkpoints = split_space(self.uuid.int, other.uuid.int, count)
return [attrs.evolve(self, uuid=i) for i in checkpoints]

def __int__(self):
return self.int
return self.uuid.int

def __add__(self, other: int) -> Self:
if isinstance(other, int):
return self.new(int=self.int + other)
return attrs.evolve(self, uuid=self.uuid.int + other)
return NotImplemented

def __sub__(self, other: Union[UUID, int]):
def __sub__(self, other: Union["ArithUUID", int]):
if isinstance(other, int):
return self.new(int=self.int - other)
elif isinstance(other, UUID):
return self.int - other.int
return attrs.evolve(self, uuid=self.uuid.int - other)
elif isinstance(other, ArithUUID):
return self.uuid.int - other.uuid.int
return NotImplemented

def __eq__(self, other: object) -> bool:
if isinstance(other, ArithUUID):
return self.uuid == other.uuid
return NotImplemented

def __ne__(self, other: object) -> bool:
if isinstance(other, ArithUUID):
return self.uuid != other.uuid
return NotImplemented

def __gt__(self, other: object) -> bool:
if isinstance(other, ArithUUID):
return self.uuid > other.uuid
return NotImplemented

def __lt__(self, other: object) -> bool:
if isinstance(other, ArithUUID):
return self.uuid < other.uuid
return NotImplemented

def __ge__(self, other: object) -> bool:
if isinstance(other, ArithUUID):
return self.uuid >= other.uuid
return NotImplemented

def __le__(self, other: object) -> bool:
if isinstance(other, ArithUUID):
return self.uuid <= other.uuid
return NotImplemented


Expand Down
Loading