Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 8f55fb4

Browse files
authored
Merge pull request #813 from datafold/uuid-misdetection
Do not detect MD5s as UUIDs, and preserve UUID casing for UUID PKs
2 parents 0978570 + 6886ecc commit 8f55fb4

File tree

7 files changed

+271
-28
lines changed

7 files changed

+271
-28
lines changed

data_diff/abcs/database_types.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,8 @@ def python_type(self) -> type:
182182
"Return the equivalent Python type of the key"
183183

184184
def make_value(self, value):
185+
if isinstance(value, self.python_type):
186+
return value
185187
return self.python_type(value)
186188

187189

@@ -217,7 +219,14 @@ class Native_UUID(ColType_UUID):
217219

218220
@attrs.define(frozen=True)
219221
class String_UUID(ColType_UUID, StringType):
220-
pass
222+
# Case is important for UUIDs stored as regular string, not native UUIDs stored as numbers.
223+
# We slice them internally as numbers, but render them back to SQL as lower/upper case.
224+
# None means we do not know for sure, behave as with False, but it might be unreliable.
225+
lowercase: Optional[bool] = None
226+
uppercase: Optional[bool] = None
227+
228+
def make_value(self, v: str) -> ArithUUID:
229+
return self.python_type(v, lowercase=self.lowercase, uppercase=self.uppercase)
221230

222231

223232
@attrs.define(frozen=True)
@@ -230,9 +239,6 @@ def test_value(value: str) -> bool:
230239
except ValueError:
231240
return False
232241

233-
def make_value(self, value):
234-
return self.python_type(value)
235-
236242

237243
@attrs.define(frozen=True)
238244
class String_VaryingAlphanum(String_Alphanum):
@@ -244,6 +250,8 @@ class String_FixedAlphanum(String_Alphanum):
244250
length: int
245251

246252
def make_value(self, value):
253+
if isinstance(value, self.python_type):
254+
return value
247255
if len(value) != self.length:
248256
raise ValueError(f"Expected alphanumeric value of length {self.length}, but got '{value}'.")
249257
return self.python_type(value, max_len=self.length)

data_diff/databases/base.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from data_diff.abcs.compiler import AbstractCompiler, Compilable
2121
from data_diff.queries.extras import ApplyFuncAndNormalizeAsString, Checksum, NormalizeAsString
2222
from data_diff.schema import RawColumnInfo
23-
from data_diff.utils import ArithString, is_uuid, join_iter, safezip
23+
from data_diff.utils import ArithString, ArithUUID, is_uuid, join_iter, safezip
2424
from data_diff.queries.api import Expr, table, Select, SKIP, Explain, Code, this
2525
from data_diff.queries.ast_classes import (
2626
Alias,
@@ -248,6 +248,9 @@ def _compile(self, compiler: Compiler, elem) -> str:
248248
return self.timestamp_value(elem)
249249
elif isinstance(elem, bytes):
250250
return f"b'{elem.decode()}'"
251+
elif isinstance(elem, ArithUUID):
252+
s = f"'{elem.uuid}'"
253+
return s.upper() if elem.uppercase else s.lower() if elem.lowercase else s
251254
elif isinstance(elem, ArithString):
252255
return f"'{elem}'"
253256
assert False, elem
@@ -681,8 +684,10 @@ def _constant_value(self, v):
681684
return f"'{v}'"
682685
elif isinstance(v, datetime):
683686
return self.timestamp_value(v)
684-
elif isinstance(v, UUID):
687+
elif isinstance(v, UUID): # probably unused anymore in favour of ArithUUID
685688
return f"'{v}'"
689+
elif isinstance(v, ArithUUID):
690+
return f"'{v.uuid}'"
686691
elif isinstance(v, decimal.Decimal):
687692
return str(v)
688693
elif isinstance(v, bytearray):
@@ -1110,7 +1115,10 @@ def _refine_coltypes(
11101115
)
11111116
else:
11121117
assert col_name in col_dict
1113-
col_dict[col_name] = String_UUID()
1118+
col_dict[col_name] = String_UUID(
1119+
lowercase=all(s == s.lower() for s in uuid_samples),
1120+
uppercase=all(s == s.upper() for s in uuid_samples),
1121+
)
11141122
continue
11151123

11161124
if self.SUPPORTS_ALPHANUMS: # Anything but MySQL (so far)

data_diff/diff_tables.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,8 @@ def _bisect_and_diff_tables(self, table1: TableSegment, table2: TableSegment, in
300300
# Start with the first completed value, so we don't waste time waiting
301301
min_key1, max_key1 = self._parse_key_range_result(key_types1, next(key_ranges))
302302

303-
btable1, btable2 = [t.new_key_bounds(min_key=min_key1, max_key=max_key1) for t in (table1, table2)]
303+
btable1 = table1.new_key_bounds(min_key=min_key1, max_key=max_key1, key_types=key_types1)
304+
btable2 = table2.new_key_bounds(min_key=min_key1, max_key=max_key1, key_types=key_types2)
304305

305306
logger.info(
306307
f"Diffing segments at key-range: {btable1.min_key}..{btable2.max_key}. "
@@ -324,16 +325,18 @@ def _bisect_and_diff_tables(self, table1: TableSegment, table2: TableSegment, in
324325
# └──┴──────┴──┘
325326
# Overall, the max number of new regions in this 2nd pass is 3^|k| - 1
326327

327-
min_key2, max_key2 = self._parse_key_range_result(key_types1, next(key_ranges))
328+
# Note: python types can be the same, but the rendering parameters (e.g. casing) can differ.
329+
min_key2, max_key2 = self._parse_key_range_result(key_types2, next(key_ranges))
328330

329331
points = [list(sorted(p)) for p in safezip(min_key1, min_key2, max_key1, max_key2)]
330332
box_mesh = create_mesh_from_points(*points)
331333

332334
new_regions = [(p1, p2) for p1, p2 in box_mesh if p1 < p2 and not (p1 >= min_key1 and p2 <= max_key1)]
333335

334336
for p1, p2 in new_regions:
335-
extra_tables = [t.new_key_bounds(min_key=p1, max_key=p2) for t in (table1, table2)]
336-
ti.submit(self._bisect_and_diff_segments, ti, *extra_tables, info_tree, priority=999)
337+
extra_table1 = table1.new_key_bounds(min_key=p1, max_key=p2, key_types=key_types1)
338+
extra_table2 = table2.new_key_bounds(min_key=p1, max_key=p2, key_types=key_types2)
339+
ti.submit(self._bisect_and_diff_segments, ti, extra_table1, extra_table2, info_tree, priority=999)
337340

338341
return ti
339342

data_diff/table_segment.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import time
2-
from typing import Container, Dict, List, Optional, Tuple
2+
from typing import Container, Dict, List, Optional, Sequence, Tuple
33
import logging
44
from itertools import product
55

@@ -9,7 +9,7 @@
99
from data_diff.utils import safezip, Vector
1010
from data_diff.utils import ArithString, split_space
1111
from data_diff.databases.base import Database
12-
from data_diff.abcs.database_types import DbPath, DbKey, DbTime
12+
from data_diff.abcs.database_types import DbPath, DbKey, DbTime, IKey
1313
from data_diff.schema import RawColumnInfo, Schema, create_schema
1414
from data_diff.queries.extras import Checksum
1515
from data_diff.queries.api import Count, SKIP, table, this, Expr, min_, max_, Code
@@ -205,7 +205,7 @@ def new(self, **kwargs) -> Self:
205205
"""Creates a copy of the instance using 'replace()'"""
206206
return attrs.evolve(self, **kwargs)
207207

208-
def new_key_bounds(self, min_key: Vector, max_key: Vector) -> Self:
208+
def new_key_bounds(self, min_key: Vector, max_key: Vector, *, key_types: Optional[Sequence[IKey]] = None) -> Self:
209209
if self.min_key is not None:
210210
assert self.min_key <= min_key, (self.min_key, min_key)
211211
assert self.min_key < max_key
@@ -214,6 +214,13 @@ def new_key_bounds(self, min_key: Vector, max_key: Vector) -> Self:
214214
assert min_key < self.max_key
215215
assert max_key <= self.max_key
216216

217+
# If asked, enforce the PKs to proper types, mainly to meta-params of the relevant side,
218+
# so that we do not leak e.g. casing of UUIDs from side A to side B and vice versa.
219+
# If not asked, keep the meta-params of the keys as is (assume them already casted).
220+
if key_types is not None:
221+
min_key = Vector(type.make_value(val) for type, val in safezip(key_types, min_key))
222+
max_key = Vector(type.make_value(val) for type, val in safezip(key_types, max_key))
223+
217224
return attrs.evolve(self, min_key=min_key, max_key=max_key)
218225

219226
@property

data_diff/utils.py

Lines changed: 68 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,14 @@ def safezip(*args):
4343
return zip(*args)
4444

4545

46-
def is_uuid(u):
46+
UUID_PATTERN = re.compile(r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}", re.I)
47+
48+
49+
def is_uuid(u: str) -> bool:
50+
# E.g., hashlib.md5(b'hello') is a 32-letter hex number, but not an UUID.
51+
# It would fail UUID-like comparison (< & >) because of casing and dashes.
52+
if not UUID_PATTERN.fullmatch(u):
53+
return False
4754
try:
4855
UUID(u)
4956
except ValueError:
@@ -128,23 +135,75 @@ def range(self, other: "ArithString", count: int) -> List[Self]:
128135
return [self.new(int=i) for i in checkpoints]
129136

130137

131-
# @attrs.define # not as long as it inherits from UUID
132-
class ArithUUID(UUID, ArithString):
138+
def _any_to_uuid(v: Union[str, int, UUID, "ArithUUID"]) -> UUID:
139+
if isinstance(v, ArithUUID):
140+
return v.uuid
141+
elif isinstance(v, UUID):
142+
return v
143+
elif isinstance(v, str):
144+
return UUID(v)
145+
elif isinstance(v, int):
146+
return UUID(int=v)
147+
else:
148+
raise ValueError(f"Cannot convert a value to UUID: {v!r}")
149+
150+
151+
@attrs.define(frozen=True, eq=False, order=False)
152+
class ArithUUID(ArithString):
133153
"A UUID that supports basic arithmetic (add, sub)"
134154

155+
uuid: UUID = attrs.field(converter=_any_to_uuid)
156+
lowercase: Optional[bool] = None
157+
uppercase: Optional[bool] = None
158+
159+
def range(self, other: "ArithUUID", count: int) -> List[Self]:
160+
assert isinstance(other, ArithUUID)
161+
checkpoints = split_space(self.uuid.int, other.uuid.int, count)
162+
return [attrs.evolve(self, uuid=i) for i in checkpoints]
163+
135164
def __int__(self):
136-
return self.int
165+
return self.uuid.int
137166

138167
def __add__(self, other: int) -> Self:
139168
if isinstance(other, int):
140-
return self.new(int=self.int + other)
169+
return attrs.evolve(self, uuid=self.uuid.int + other)
141170
return NotImplemented
142171

143-
def __sub__(self, other: Union[UUID, int]):
172+
def __sub__(self, other: Union["ArithUUID", int]):
144173
if isinstance(other, int):
145-
return self.new(int=self.int - other)
146-
elif isinstance(other, UUID):
147-
return self.int - other.int
174+
return attrs.evolve(self, uuid=self.uuid.int - other)
175+
elif isinstance(other, ArithUUID):
176+
return self.uuid.int - other.uuid.int
177+
return NotImplemented
178+
179+
def __eq__(self, other: object) -> bool:
180+
if isinstance(other, ArithUUID):
181+
return self.uuid == other.uuid
182+
return NotImplemented
183+
184+
def __ne__(self, other: object) -> bool:
185+
if isinstance(other, ArithUUID):
186+
return self.uuid != other.uuid
187+
return NotImplemented
188+
189+
def __gt__(self, other: object) -> bool:
190+
if isinstance(other, ArithUUID):
191+
return self.uuid > other.uuid
192+
return NotImplemented
193+
194+
def __lt__(self, other: object) -> bool:
195+
if isinstance(other, ArithUUID):
196+
return self.uuid < other.uuid
197+
return NotImplemented
198+
199+
def __ge__(self, other: object) -> bool:
200+
if isinstance(other, ArithUUID):
201+
return self.uuid >= other.uuid
202+
return NotImplemented
203+
204+
def __le__(self, other: object) -> bool:
205+
if isinstance(other, ArithUUID):
206+
return self.uuid <= other.uuid
148207
return NotImplemented
149208

150209

0 commit comments

Comments
 (0)