1
1
"""Provides classes for performing a table diff
2
2
"""
3
3
4
+ from operator import attrgetter
4
5
from typing import List , Tuple , Iterator , Literal
5
6
import logging
7
+ from concurrent .futures import ThreadPoolExecutor
6
8
7
9
from runtype import dataclass
8
10
11
13
12
14
logger = logging .getLogger ("diff_tables" )
13
15
16
+ # Global task pool for the table diff algorithm
17
+ g_task_pool = ThreadPoolExecutor ()
18
+
14
19
15
20
def safezip (* args ):
16
21
"zip but makes sure all sequences are the same length"
@@ -142,6 +147,10 @@ def diff_sets(a: set, b: set) -> iter:
142
147
DiffResult = Iterator [Tuple [Literal ["+" , "-" ], tuple ]]
143
148
144
149
150
+ def precalc_attr (attr , iter ):
151
+ return list (g_task_pool .map (attrgetter (attr ), iter ))
152
+
153
+
145
154
@dataclass
146
155
class TableDiffer :
147
156
"""Finds the diff between two SQL tables
@@ -165,22 +174,26 @@ def diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult:
165
174
Where `columns` is a tuple of values for the involved columns, i.e. (id, ...extra)
166
175
"""
167
176
if self .bisection_factor >= self .bisection_threshold :
168
- raise ValueError ("Incorrect param values" )
177
+ raise ValueError ("Incorrect param values (bisection factor must be lower than threshold) " )
169
178
if self .bisection_factor < 2 :
170
- raise ValueError ("Must have at least two segments per iteration" )
179
+ raise ValueError ("Must have at least two segments per iteration (i.e. bisection_factor >= 2) " )
171
180
172
- logger .info (
173
- f"Diffing tables of size { table1 .count } and { table2 .count } | segments: { self .bisection_factor } , bisection threshold: { self .bisection_threshold } ."
174
- )
181
+ return self ._diff_tables (table1 , table2 )
175
182
176
- if table1 . checksum == table2 . checksum :
177
- return [] # No differences
183
+ def _diff_tables ( self , table1 , table2 , level = 0 , segment_index = None , segment_count = None ) :
184
+ count1 , count2 = precalc_attr ( "count" , [ table1 , table2 ])
178
185
179
- return self ._diff_tables (table1 , table2 )
186
+ if segment_index :
187
+ logger .info (". " * level + f"Diffing segment { segment_index } /{ segment_count } of size { count1 } and { count2 } " )
188
+ else :
189
+ logger .info (
190
+ f"Diffing tables of size { table1 .count } and { table2 .count } | segments: { self .bisection_factor } , bisection threshold: { self .bisection_threshold } ."
191
+ )
192
+
193
+ checksum1 , checksum2 = precalc_attr ("checksum" , [table1 , table2 ]) # Calculate checksum in parallel
180
194
181
- def _diff_tables (self , table1 , table2 , level = 0 ):
182
- count1 = table1 .count
183
- count2 = table2 .count
195
+ if checksum1 == checksum2 :
196
+ return # No differences
184
197
185
198
# If count is below the threshold, just download and compare the columns locally
186
199
# This saves time, as bisection speed is limited by ping and query performance.
@@ -204,14 +217,18 @@ def _diff_tables(self, table1, table2, level=0):
204
217
# Create new instances of TableSegment between each checkpoint
205
218
segmented1 = table1 .segment_by_checkpoints (mutual_checkpoints )
206
219
segmented2 = table2 .segment_by_checkpoints (mutual_checkpoints )
220
+
207
221
if self .debug :
208
222
logger .debug ("Performing sanity tests for chosen segments (assert sum of fragments == whole)" )
223
+ precalc_attr ("count" , segmented1 + segmented2 )
209
224
assert count1 == sum (s .count for s in segmented1 )
210
225
assert count2 == sum (s .count for s in segmented2 )
211
226
212
- # Compare each pair of corresponding segments between table1 and table2
213
- for i , (t1 , t2 ) in enumerate (safezip (segmented1 , segmented2 )):
214
- logger .info (". " * level + f"Diffing segment { i + 1 } /{ len (segmented1 )} of size { t1 .count } and { t2 .count } " )
215
- if t1 .checksum != t2 .checksum :
216
- # Apply recursively
217
- yield from self ._diff_tables (t1 , t2 , level + 1 )
227
+ # Recursively compare each pair of corresponding segments between table1 and table2
228
+ diff_iters = [
229
+ self ._diff_tables (t1 , t2 , level + 1 , i + 1 , len (segmented1 ))
230
+ for i , (t1 , t2 ) in enumerate (safezip (segmented1 , segmented2 ))
231
+ ]
232
+
233
+ for res in g_task_pool .map (list , diff_iters ):
234
+ yield from res
0 commit comments