1
1
"""Provides classes for performing a table diff
2
2
"""
3
3
4
+ < << << << HEAD
4
5
from collections import defaultdict
5
6
from typing import List , Tuple
6
7
import logging
8
+ == == == =
9
+ from typing import List , Tuple , Iterator , Literal
10
+ import logging
11
+ import datetime
12
+ > >> >> >> 8914 eaf (no full index scans )
7
13
8
14
from runtype import dataclass
9
15
@@ -38,6 +44,19 @@ def __post_init__(self):
38
44
if not self .update_column and (self .min_time or self .max_time ):
39
45
raise ValueError ("Error: min_time/max_time feature requires to specify 'update_column'" )
40
46
47
+ < << << << HEAD
48
+ == == == =
49
+ # This will only happen on the first TableSegment
50
+ if self .start_key is None or self .end_key is None :
51
+ select = self ._make_select (columns = [f"min({ self .key_column } )" , f"max({ self .key_column } )" ])
52
+ res = self .database .query (select , Tuple )[0 ] or (0 , 0 )
53
+
54
+ if self .start_key is None :
55
+ self .start_key = res [0 ]
56
+ if self .end_key is None :
57
+ self .end_key = res [1 ]
58
+
59
+ > >> >> >> 8914 eaf (no full index scans )
41
60
def _make_key_range (self ):
42
61
if self .start_key is not None :
43
62
yield Compare ("<=" , str (self .start_key ), self .key_column )
@@ -50,7 +69,12 @@ def _make_update_range(self):
50
69
if self .max_time is not None :
51
70
yield Compare ("<" , self .update_column , Time (self .max_time ))
52
71
72
+ < << << << HEAD
53
73
def _make_select (self , * , table = None , columns = None , where = None , group_by = None , order_by = None ):
74
+ == == == =
75
+ def _make_select (self , * , table = None , columns = None , where = None ,
76
+ group_by = None , order_by = None , where_or = None ):
77
+ >> >> >> > 8914 eaf (no full index scans )
54
78
if columns is None :
55
79
columns = [self .key_column ]
56
80
where = list (self ._make_key_range ()) + list (self ._make_update_range ()) + ([] if where is None else [where ])
@@ -60,6 +84,10 @@ def _make_select(self, *, table=None, columns=None, where=None, group_by=None, o
60
84
where = where ,
61
85
columns = columns ,
62
86
group_by = group_by ,
87
+ << << << < HEAD
88
+ == == == =
89
+ where_or = where_or ,
90
+ >> >> >> > 8914 eaf (no full index scans )
63
91
order_by = order_by ,
64
92
)
65
93
@@ -68,13 +96,37 @@ def get_values(self) -> list:
68
96
select = self ._make_select (columns = self ._relevant_columns )
69
97
return self .database .query (select , List [Tuple ])
70
98
99
+ << << << < HEAD
71
100
def choose_checkpoints (self , count : int ) -> List [DbKey ]:
72
101
"Suggests a bunch of evenly-spaced checkpoints to split by"
73
102
ratio = int (self .count / count )
74
103
assert ratio > 1
75
104
skip = f"mod(idx, { ratio } ) = 0"
76
105
select = self ._make_select (table = Enum (self .table_path , order_by = self .key_column ), where = skip )
77
106
return self .database .query (select , List [int ])
107
+ == == == =
108
+ def choose_checkpoints (self , bisection_factor : int ) -> List [DbKey ]:
109
+ "Suggests a bunch of evenly-spaced checkpoints to split by"
110
+ gap = round ((self .end_key - self .start_key + 1 ) / bisection_factor )
111
+ assert gap >= 1
112
+
113
+ checkpoints = [self .start_key + gap ]
114
+ for i in range (bisection_factor - 1 ):
115
+ checkpoints .append (checkpoints [i ] + gap )
116
+
117
+ # The _make_select will ensure it's still within the valid key space!
118
+ lookaround = 1000
119
+
120
+ columns = []
121
+ where_or = []
122
+ for i in range (bisection_factor - 1 ):
123
+ columns .append (f"MAX(CASE WHEN id >= { checkpoints [i ]- lookaround } AND id < { checkpoints [i ]} THEN id ELSE -1 END)" )
124
+ where_or .append (f"(id >= { checkpoints [i ]- lookaround } AND id < { checkpoints [i ]} )" )
125
+
126
+ select = self ._make_select (columns = columns , where_or = where_or )
127
+ real_checkpoints = self .database .query (select , List [Tuple ])
128
+ return list (real_checkpoints [0 ])
129
+ >> >> >> > 8914 eaf (no full index scans )
78
130
79
131
def find_checkpoints (self , checkpoints : List [DbKey ]) -> List [DbKey ]:
80
132
"Takes a list of potential checkpoints and returns those that exist"
@@ -97,43 +149,75 @@ def segment_by_checkpoints(self, checkpoints: List[DbKey]) -> List["TableSegment
97
149
98
150
return tables
99
151
152
+ < << << << HEAD
100
153
## Calculate checksums in one go, to prevent repetitive individual calls
101
154
# selects = [t._make_select(columns=[Checksum(self._relevant_columns)]) for t in tables]
102
155
# res = self.database.query(Select(columns=selects), list)
103
156
# checksums ,= res
104
157
# assert len(checksums) == len(checkpoints) + 1
105
158
# return [t.new(_checksum=checksum) for t, checksum in safezip(tables, checksums)]
106
159
160
+ == == == =
161
+ >> >> >> > 8914 eaf (no full index scans )
107
162
def new (self , _count = None , _checksum = None , ** kwargs ) -> "TableSegment" :
108
163
"""Using new() creates a copy of the instance using 'replace()', and makes sure the cache is reset"""
109
164
return self .replace (_count = None , _checksum = None , ** kwargs )
110
165
166
+ < << << << HEAD
111
167
@property
112
168
def count (self ) -> int :
113
169
if self ._count is None :
114
170
self ._count = self .database .query (self ._make_select (columns = [Count ()]), int )
171
+ == == == =
172
+ def __repr__ (self ):
173
+ return f"{ type (self .database ).__name__ } /{ ', ' .join (self .table_path )} "
174
+
175
+ @property
176
+ def count (self ) -> int :
177
+ if self ._count is None :
178
+ raise ValueError ("You should always get the count after the checksum to avoid another index scan" )
179
+ >> >> >> > 8914 eaf (no full index scans )
115
180
return self ._count
116
181
117
182
@property
118
183
def _relevant_columns (self ) -> List [str ]:
184
+ < << << << HEAD
119
185
return (
120
186
[self .key_column ]
121
187
+ ([self .update_column ] if self .update_column is not None else [])
122
188
+ list (self .extra_columns )
123
189
)
190
+ == == == =
191
+ return list (set (
192
+ [self .key_column ]
193
+ + ([self .update_column ] if self .update_column is not None else [])
194
+ + list (self .extra_columns )
195
+ ))
196
+ >> >> >> > 8914 eaf (no full index scans )
124
197
125
198
@property
126
199
def checksum (self ) -> int :
127
200
if self ._checksum is None :
201
+ < << << << HEAD
128
202
self ._checksum = (
129
203
self .database .query (self ._make_select (columns = [Checksum (self ._relevant_columns )]), int ) or 0
130
204
)
205
+ == == == =
206
+ # Get the count in the same index pass. Much cheaper than doing it
207
+ # separately.
208
+ select = self ._make_select (columns = [Count (), Checksum (self ._relevant_columns )])
209
+ result = self .database .query (select , Tuple )
210
+ self ._checksum = int (result [0 ][1 ])
211
+ self ._count = result [0 ][0 ]
212
+
213
+ >> >> >> > 8914 eaf (no full index scans )
131
214
return self ._checksum
132
215
133
216
134
217
def diff_sets (a : set , b : set ) -> iter :
135
218
s1 = set (a )
136
219
s2 = set (b )
220
+ < << << << HEAD
137
221
d = defaultdict (list )
138
222
139
223
# The first item is always the key (see TableDiffer._relevant_columns)
@@ -147,6 +231,15 @@ def diff_sets(a: set, b: set) -> iter:
147
231
148
232
149
233
DiffResult = iter # Iterator[Tuple[Literal["+", "-"], tuple]]
234
+ == == == =
235
+ for i in s1 - s2 :
236
+ yield "+" , i
237
+ for i in s2 - s1 :
238
+ yield "-" , i
239
+
240
+
241
+ DiffResult = Iterator [Tuple [Literal ["+" , "-" ], tuple ]]
242
+ > >> >> >> 8914 eaf (no full index scans )
150
243
151
244
152
245
@dataclass
@@ -160,7 +253,11 @@ class TableDiffer:
160
253
"""
161
254
162
255
bisection_factor : int = 32 # Into how many segments to bisect per iteration
256
+ < << << << HEAD
163
257
bisection_threshold : int = 1024 ** 2 # When should we stop bisecting and compare locally (in row count)
258
+ == == == =
259
+ bisection_threshold : int = 10000 # When should we stop bisecting and compare locally (in row count)
260
+ >> >> >> > 8914 eaf (no full index scans )
164
261
debug : bool = False
165
262
166
263
def diff_tables (self , table1 : TableSegment , table2 : TableSegment ) -> DiffResult :
@@ -177,6 +274,7 @@ def diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult:
177
274
raise ValueError ("Must have at least two segments per iteration" )
178
275
179
276
logger .info (
277
+ << << << < HEAD
180
278
f"Diffing tables of size { table1 .count } and { table2 .count } | segments: { self .bisection_factor } , bisection threshold: { self .bisection_threshold } ."
181
279
)
182
280
@@ -222,3 +320,60 @@ def _diff_tables(self, table1, table2, level=0):
222
320
if t1 .checksum != t2 .checksum :
223
321
# Apply recursively
224
322
yield from self ._diff_tables (t1 , t2 , level + 1 )
323
+ == == == =
324
+ f"Diffing tables { repr (table1 )} and { repr (table2 )} | segments: { self .bisection_factor } , bisection threshold: { self .bisection_threshold } ."
325
+ )
326
+
327
+ return self ._diff_tables (table1 , table2 )
328
+
329
+ def _diff_tables (self , table1 , table2 , level = 0 , bisection_factor = None ):
330
+ if bisection_factor is None :
331
+ bisection_factor = self .bisection_factor
332
+ if level > 50 :
333
+ raise Exception ("Recursing too far; likely infinite loop" )
334
+
335
+ # TODO: As an optimization, get an approximate count here from the
336
+ # database's information tables (if available), and if it's roughly
337
+ # below the threshold, then allow getting the values on the first pass.
338
+
339
+ # We only check beyond level > 0, because otherwise we might scan the
340
+ # entire index in one query. For large tables with billions of rows, we
341
+ # need to split by the `bisection_factor`.
342
+ if level > 0 :
343
+ count1 = table1 .count
344
+ count2 = table2 .count
345
+ # TODO: MAX KEY - MIN_KEY + 1 too?
346
+
347
+ # If count is below the threshold, just download and compare the columns locally
348
+ # This saves time, as bisection speed is limited by ping and query performance.
349
+ if count1 < self .bisection_threshold and count2 < self .bisection_threshold :
350
+ rows1 = table1 .get_values ()
351
+ rows2 = table2 .get_values ()
352
+ diff = list (diff_sets (rows1 , rows2 ))
353
+ logger .info (". " * level + f"Diff found { len (diff )} different rows." )
354
+ yield from diff
355
+ return
356
+
357
+ # Find mutual checkpoints between the two tables
358
+ checkpoints = table1 .choose_checkpoints (bisection_factor )
359
+ assert checkpoints
360
+ mutual_checkpoints = table2 .find_checkpoints ([Value (c ) for c in checkpoints ])
361
+ mutual_checkpoints = list (set (mutual_checkpoints )) # Duplicate values are a problem!
362
+ mutual_checkpoints .sort ()
363
+ # print(f"level={level} cp={checkpoints} mc={mutual_checkpoints} bf={bisection_factor} t1start_key={table1.start_key} t1end_key={table1.end_key} t2_start_key={table2.start_key} t2_end_key={table2.end_key}")
364
+ logger .debug (". " * level + f"Found { len (mutual_checkpoints )} mutual checkpoints (out of { len (checkpoints )} ) origin={ checkpoints } mutual={ mutual_checkpoints } " )
365
+ if not mutual_checkpoints :
366
+ raise Exception ("Tables are too different." )
367
+
368
+
369
+ # Create new instances of TableSegment between each checkpoint
370
+ segmented1 = table1 .segment_by_checkpoints (mutual_checkpoints )
371
+ segmented2 = table2 .segment_by_checkpoints (mutual_checkpoints )
372
+ # print(segmented1)
373
+
374
+ # Compare each pair of corresponding segments between table1 and table2
375
+ for i , (t1 , t2 ) in enumerate (safezip (segmented1 , segmented2 )):
376
+ logger .info (". " * level + f"Diffing segment { i + 1 } /{ len (segmented1 )} keys={ t1 .start_key } ..{ t1 .end_key } " )
377
+ if t1 .checksum != t2 .checksum :
378
+ yield from self ._diff_tables (t1 , t2 , level + 1 , max (int (bisection_factor / 2 ), 2 ))
379
+ >> >> >> > 8914 eaf (no full index scans )
0 commit comments