This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit cd0e5ed
no full index scans
1 parent ee35ee2

File tree

2 files changed: +302 -0 lines changed

data_diff/diff_tables.py

Lines changed: 155 additions & 0 deletions
@@ -1,9 +1,15 @@
 """Provides classes for performing a table diff
 """
 
+<<<<<<< HEAD
 from collections import defaultdict
 from typing import List, Tuple
 import logging
+=======
+from typing import List, Tuple, Iterator, Literal
+import logging
+import datetime
+>>>>>>> 8914eaf (no full index scans)
 
 from runtype import dataclass
 
@@ -38,6 +44,19 @@ def __post_init__(self):
         if not self.update_column and (self.min_time or self.max_time):
             raise ValueError("Error: min_time/max_time feature requires to specify 'update_column'")
 
+<<<<<<< HEAD
+=======
+        # This will only happen on the first TableSegment
+        if self.start_key is None or self.end_key is None:
+            select = self._make_select(columns=[f"min({self.key_column})", f"max({self.key_column})"])
+            res = self.database.query(select, Tuple)[0] or (0, 0)
+
+            if self.start_key is None:
+                self.start_key = res[0]
+            if self.end_key is None:
+                self.end_key = res[1]
+
+>>>>>>> 8914eaf (no full index scans)
     def _make_key_range(self):
         if self.start_key is not None:
             yield Compare("<=", str(self.start_key), self.key_column)
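The added `__post_init__` branch is what removes the first full index scan: instead of enumerating or counting the key column, the first TableSegment seeds its key range with a single min/max aggregate. Below is a rough illustration of that seeding query, using a hypothetical table name and `key_column="id"`; the real SQL is rendered by the project's `_make_select`/`Select` machinery, which this diff does not show.

key_column = "id"          # hypothetical; comes from the TableSegment config
table = "ratings"          # hypothetical table name, for illustration only
seed_sql = f"SELECT min({key_column}), max({key_column}) FROM {table}"
print(seed_sql)
# A result row such as (1, 987654) becomes start_key=1, end_key=987654;
# an empty result falls back to (0, 0), matching the `or (0, 0)` in the diff.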
@@ -50,7 +69,12 @@ def _make_update_range(self):
         if self.max_time is not None:
             yield Compare("<", self.update_column, Time(self.max_time))
 
+<<<<<<< HEAD
     def _make_select(self, *, table=None, columns=None, where=None, group_by=None, order_by=None):
+=======
+    def _make_select(self, *, table=None, columns=None, where=None,
+                     group_by=None, order_by=None, where_or=None):
+>>>>>>> 8914eaf (no full index scans)
         if columns is None:
             columns = [self.key_column]
         where = list(self._make_key_range()) + list(self._make_update_range()) + ([] if where is None else [where])
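The new `where_or` parameter exists so that `choose_checkpoints()` (further down) can pass a list of narrow key windows that are OR'd together rather than AND'd with the rest of the filter. Here is a hedged sketch of how such a clause plausibly renders, assuming the OR group is appended as one term to the usual AND conditions; the actual composition lives in the project's `Select` class, which is not shown in this diff.

def render_where(where, where_or):
    # AND together the range/update filters, then attach the OR'd windows
    # as a single parenthesised term (assumption about Select's behaviour).
    clauses = list(where)
    if where_or:
        clauses.append("(" + " OR ".join(where_or) + ")")
    return " AND ".join(clauses)

print(render_where(
    where=["id >= 1", "id <= 1000000"],
    where_or=["(id >= 249001 AND id < 250001)", "(id >= 499001 AND id < 500001)"],
))
# id >= 1 AND id <= 1000000 AND ((id >= 249001 AND id < 250001) OR (id >= 499001 AND id < 500001))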
@@ -60,6 +84,10 @@ def _make_select(self, *, table=None, columns=None, where=None, group_by=None, o
             where=where,
             columns=columns,
             group_by=group_by,
+<<<<<<< HEAD
+=======
+            where_or=where_or,
+>>>>>>> 8914eaf (no full index scans)
             order_by=order_by,
         )
 
@@ -68,13 +96,37 @@ def get_values(self) -> list:
         select = self._make_select(columns=self._relevant_columns)
         return self.database.query(select, List[Tuple])
 
+<<<<<<< HEAD
     def choose_checkpoints(self, count: int) -> List[DbKey]:
         "Suggests a bunch of evenly-spaced checkpoints to split by"
         ratio = int(self.count / count)
         assert ratio > 1
         skip = f"mod(idx, {ratio}) = 0"
         select = self._make_select(table=Enum(self.table_path, order_by=self.key_column), where=skip)
         return self.database.query(select, List[int])
+=======
+    def choose_checkpoints(self, bisection_factor: int) -> List[DbKey]:
+        "Suggests a bunch of evenly-spaced checkpoints to split by"
+        gap = round((self.end_key - self.start_key + 1) / bisection_factor)
+        assert gap >= 1
+
+        checkpoints = [self.start_key + gap]
+        for i in range(bisection_factor - 1):
+            checkpoints.append(checkpoints[i] + gap)
+
+        # The _make_select will ensure it's still within the valid key space!
+        lookaround = 1000
+
+        columns = []
+        where_or = []
+        for i in range(bisection_factor - 1):
+            columns.append(f"MAX(CASE WHEN id >= {checkpoints[i]-lookaround} AND id < {checkpoints[i]} THEN id ELSE -1 END)")
+            where_or.append(f"(id >= {checkpoints[i]-lookaround} AND id < {checkpoints[i]})")
+
+        select = self._make_select(columns=columns, where_or=where_or)
+        real_checkpoints = self.database.query(select, List[Tuple])
+        return list(real_checkpoints[0])
+>>>>>>> 8914eaf (no full index scans)
 
     def find_checkpoints(self, checkpoints: List[DbKey]) -> List[DbKey]:
         "Takes a list of potential checkpoints and returns those that exist"
@@ -97,43 +149,75 @@ def segment_by_checkpoints(self, checkpoints: List[DbKey]) -> List["TableSegment
 
         return tables
 
+<<<<<<< HEAD
         ## Calculate checksums in one go, to prevent repetitive individual calls
         # selects = [t._make_select(columns=[Checksum(self._relevant_columns)]) for t in tables]
         # res = self.database.query(Select(columns=selects), list)
         # checksums ,= res
         # assert len(checksums) == len(checkpoints) + 1
         # return [t.new(_checksum=checksum) for t, checksum in safezip(tables, checksums)]
 
+=======
+>>>>>>> 8914eaf (no full index scans)
     def new(self, _count=None, _checksum=None, **kwargs) -> "TableSegment":
         """Using new() creates a copy of the instance using 'replace()', and makes sure the cache is reset"""
         return self.replace(_count=None, _checksum=None, **kwargs)
 
+<<<<<<< HEAD
     @property
     def count(self) -> int:
        if self._count is None:
            self._count = self.database.query(self._make_select(columns=[Count()]), int)
+=======
+    def __repr__(self):
+        return f"{type(self.database).__name__}/{', '.join(self.table_path)}"
+
+    @property
+    def count(self) -> int:
+        if self._count is None:
+            raise ValueError("You should always get the count after the checksum to avoid another index scan")
+>>>>>>> 8914eaf (no full index scans)
         return self._count
 
     @property
     def _relevant_columns(self) -> List[str]:
+<<<<<<< HEAD
         return (
             [self.key_column]
             + ([self.update_column] if self.update_column is not None else [])
             + list(self.extra_columns)
         )
+=======
+        return list(set(
+            [self.key_column]
+            + ([self.update_column] if self.update_column is not None else [])
+            + list(self.extra_columns)
+        ))
+>>>>>>> 8914eaf (no full index scans)
 
     @property
     def checksum(self) -> int:
         if self._checksum is None:
+<<<<<<< HEAD
            self._checksum = (
                self.database.query(self._make_select(columns=[Checksum(self._relevant_columns)]), int) or 0
            )
+=======
+            # Get the count in the same index pass. Much cheaper than doing it
+            # separately.
+            select = self._make_select(columns=[Count(), Checksum(self._relevant_columns)])
+            result = self.database.query(select, Tuple)
+            self._checksum = int(result[0][1])
+            self._count = result[0][0]
+
+>>>>>>> 8914eaf (no full index scans)
         return self._checksum
 
 
 def diff_sets(a: set, b: set) -> iter:
     s1 = set(a)
     s2 = set(b)
+<<<<<<< HEAD
     d = defaultdict(list)
 
     # The first item is always the key (see TableDiffer._relevant_columns)
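Two related changes sit in this hunk: `checksum` now computes `Count()` and the checksum in one SELECT, so the row count arrives in the same index pass, and the `count` property refuses to run its own query afterwards. A toy, in-memory illustration of why one pass is enough for both aggregates follows; the `hash()` below is only a stand-in for data_diff's SQL-side `Checksum`, not the project's code.

rows = [(1, "alice"), (2, "bob"), (3, "carol")]

count = 0
checksum = 0
for row in rows:             # a single pass over the segment's rows
    count += 1
    checksum += hash(row)    # stand-in for the checksum aggregate computed in SQL

print(count, checksum)
# The property caches both values, so reading .count afterwards costs nothing;
# asking for .count first would need its own scan, hence the ValueError guard.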
@@ -147,6 +231,15 @@ def diff_sets(a: set, b: set) -> iter:
 
 
 DiffResult = iter # Iterator[Tuple[Literal["+", "-"], tuple]]
+=======
+    for i in s1 - s2:
+        yield "+", i
+    for i in s2 - s1:
+        yield "-", i
+
+
+DiffResult = Iterator[Tuple[Literal["+", "-"], tuple]]
+>>>>>>> 8914eaf (no full index scans)
 
 
 @dataclass
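The new `diff_sets()` body simply labels rows that exist only on one side. A small usage example, with the function reproduced so the snippet runs on its own; rows are tuples whose first element is the key.

from typing import Iterator, Literal, Tuple

def diff_sets(a: set, b: set) -> Iterator[Tuple[Literal["+", "-"], tuple]]:
    s1 = set(a)
    s2 = set(b)
    for i in s1 - s2:
        yield "+", i
    for i in s2 - s1:
        yield "-", i

rows_a = {(1, "alice"), (2, "bob")}
rows_b = {(1, "alice"), (3, "carol")}
print(sorted(diff_sets(rows_a, rows_b)))
# [('+', (2, 'bob')), ('-', (3, 'carol'))]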
@@ -160,7 +253,11 @@ class TableDiffer:
     """
 
     bisection_factor: int = 32 # Into how many segments to bisect per iteration
+<<<<<<< HEAD
     bisection_threshold: int = 1024**2 # When should we stop bisecting and compare locally (in row count)
+=======
+    bisection_threshold: int = 10000 # When should we stop bisecting and compare locally (in row count)
+>>>>>>> 8914eaf (no full index scans)
     debug: bool = False
 
     def diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult:
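The only change to TableDiffer's defaults is lowering `bisection_threshold` from `1024**2` to `10000`, meaning segments are downloaded and compared locally much sooner. A back-of-the-envelope check of what that implies for recursion depth, illustrative only; note the commit also halves `bisection_factor` at each level, which deepens the real recursion slightly.

import math

bisection_factor = 32
bisection_threshold = 10_000

for rows in (10_000, 1_000_000, 1_000_000_000):
    # Levels of segmentation needed before every segment fits under the threshold.
    levels = max(0, math.ceil(math.log(rows / bisection_threshold, bisection_factor)))
    print(f"{rows:>13,} rows -> about {levels} level(s) of checksum queries")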
@@ -177,6 +274,7 @@ def diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult:
             raise ValueError("Must have at least two segments per iteration")
 
         logger.info(
+<<<<<<< HEAD
             f"Diffing tables of size {table1.count} and {table2.count} | segments: {self.bisection_factor}, bisection threshold: {self.bisection_threshold}."
         )
 
@@ -222,3 +320,60 @@ def _diff_tables(self, table1, table2, level=0):
             if t1.checksum != t2.checksum:
                 # Apply recursively
                 yield from self._diff_tables(t1, t2, level + 1)
+=======
+            f"Diffing tables {repr(table1)} and {repr(table2)} | segments: {self.bisection_factor}, bisection threshold: {self.bisection_threshold}."
+        )
+
+        return self._diff_tables(table1, table2)
+
+    def _diff_tables(self, table1, table2, level=0, bisection_factor=None):
+        if bisection_factor is None:
+            bisection_factor = self.bisection_factor
+        if level > 50:
+            raise Exception("Recursing too far; likely infinite loop")
+
+        # TODO: As an optimization, get an approximate count here from the
+        # database's information tables (if available), and if it's roughly
+        # below the threshold, then allow getting the values on the first pass.
+
+        # We only check beyond level > 0, because otherwise we might scan the
+        # entire index in one query. For large tables with billions of rows, we
+        # need to split by the `bisection_factor`.
+        if level > 0:
+            count1 = table1.count
+            count2 = table2.count
+            # TODO: MAX KEY - MIN_KEY + 1 too?
+
+            # If count is below the threshold, just download and compare the columns locally
+            # This saves time, as bisection speed is limited by ping and query performance.
+            if count1 < self.bisection_threshold and count2 < self.bisection_threshold:
+                rows1 = table1.get_values()
+                rows2 = table2.get_values()
+                diff = list(diff_sets(rows1, rows2))
+                logger.info(". " * level + f"Diff found {len(diff)} different rows.")
+                yield from diff
+                return
+
+        # Find mutual checkpoints between the two tables
+        checkpoints = table1.choose_checkpoints(bisection_factor)
+        assert checkpoints
+        mutual_checkpoints = table2.find_checkpoints([Value(c) for c in checkpoints])
+        mutual_checkpoints = list(set(mutual_checkpoints)) # Duplicate values are a problem!
+        mutual_checkpoints.sort()
+        # print(f"level={level} cp={checkpoints} mc={mutual_checkpoints} bf={bisection_factor} t1start_key={table1.start_key} t1end_key={table1.end_key} t2_start_key={table2.start_key} t2_end_key={table2.end_key}")
+        logger.debug(". " * level + f"Found {len(mutual_checkpoints)} mutual checkpoints (out of {len(checkpoints)}) origin={checkpoints} mutual={mutual_checkpoints}")
+        if not mutual_checkpoints:
+            raise Exception("Tables are too different.")
+
+
+        # Create new instances of TableSegment between each checkpoint
+        segmented1 = table1.segment_by_checkpoints(mutual_checkpoints)
+        segmented2 = table2.segment_by_checkpoints(mutual_checkpoints)
+        # print(segmented1)
+
+        # Compare each pair of corresponding segments between table1 and table2
+        for i, (t1, t2) in enumerate(safezip(segmented1, segmented2)):
+            logger.info(". " * level + f"Diffing segment {i+1}/{len(segmented1)} keys={t1.start_key}..{t1.end_key}")
+            if t1.checksum != t2.checksum:
+                yield from self._diff_tables(t1, t2, level + 1, max(int(bisection_factor / 2), 2))
+>>>>>>> 8914eaf (no full index scans)
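To make the recursion in the new `_diff_tables` easier to follow, here is a minimal in-memory analogue with no database: segments whose cheap "checksums" differ are split again, the split factor is halved at each level (never below 2), and small segments are materialised and compared directly. Everything in this sketch, from the set-based segmentation to the frozenset hash standing in for the SQL checksum, is illustrative rather than the project's code.

def diff_recursive(rows1, rows2, factor=32, threshold=10_000, level=0):
    # Small enough (and not the very first call): compare the rows directly,
    # mirroring the bisection_threshold branch in the commit.
    if level > 0 and len(rows1) < threshold and len(rows2) < threshold:
        yield from (("+", r) for r in rows1 - rows2)
        yield from (("-", r) for r in rows2 - rows1)
        return

    keys = sorted(rows1 | rows2)
    gap = max(1, round(len(keys) / factor))
    checkpoints = keys[gap::gap]                 # stand-in for choose_checkpoints()
    bounds = [None] + checkpoints + [None]
    for lo, hi in zip(bounds, bounds[1:]):
        seg1 = {r for r in rows1 if (lo is None or r > lo) and (hi is None or r <= hi)}
        seg2 = {r for r in rows2 if (lo is None or r > lo) and (hi is None or r <= hi)}
        if hash(frozenset(seg1)) != hash(frozenset(seg2)):   # stand-in for the SQL checksum
            yield from diff_recursive(seg1, seg2, max(factor // 2, 2), threshold, level + 1)


a = set(range(1, 50_001))
b = (a - {123, 4567}) | {50_001}
print(sorted(diff_recursive(a, b)))   # [('+', 123), ('+', 4567), ('-', 50001)]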
