Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 1c38e18

Browse files
committed
Adapt table-diff to split the workload to different threads
1 parent dab6669 commit 1c38e18

File tree

2 files changed

+36
-18
lines changed

2 files changed

+36
-18
lines changed

data_diff/database.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def _query(self, sql_code: str) -> list:
119119

120120

121121
class ThreadedDatabase(Database):
122-
"""Access the database through a singleton thread.
122+
"""Access the database through singleton threads.
123123
124124
Used for database connectors that do not support sharing their connection between different threads.
125125
"""
@@ -137,6 +137,7 @@ def _query(self, sql_code: str):
137137
return r.result()
138138

139139
def _query_in_worker(self, sql_code: str):
140+
"This method runs in a worker thread"
140141
return _query_conn(self.thread_local.conn, sql_code)
141142

142143
@abstractmethod

data_diff/diff_tables.py

Lines changed: 34 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
"""Provides classes for performing a table diff
22
"""
33

4+
from operator import attrgetter
45
from typing import List, Tuple, Iterator, Literal
56
import logging
7+
from concurrent.futures import ThreadPoolExecutor
68

79
from runtype import dataclass
810

@@ -11,6 +13,9 @@
1113

1214
logger = logging.getLogger("diff_tables")
1315

16+
# Global task pool for the table diff algorithm
17+
g_task_pool = ThreadPoolExecutor()
18+
1419

1520
def safezip(*args):
1621
"zip but makes sure all sequences are the same length"
@@ -142,6 +147,10 @@ def diff_sets(a: set, b: set) -> iter:
142147
DiffResult = Iterator[Tuple[Literal["+", "-"], tuple]]
143148

144149

150+
def precalc_attr(attr, iter):
151+
return list(g_task_pool.map(attrgetter(attr), iter))
152+
153+
145154
@dataclass
146155
class TableDiffer:
147156
"""Finds the diff between two SQL tables
@@ -165,22 +174,26 @@ def diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult:
165174
Where `columns` is a tuple of values for the involved columns, i.e. (id, ...extra)
166175
"""
167176
if self.bisection_factor >= self.bisection_threshold:
168-
raise ValueError("Incorrect param values")
177+
raise ValueError("Incorrect param values (bisection factor must be lower than threshold)")
169178
if self.bisection_factor < 2:
170-
raise ValueError("Must have at least two segments per iteration")
179+
raise ValueError("Must have at least two segments per iteration (i.e. bisection_factor >= 2)")
171180

172-
logger.info(
173-
f"Diffing tables of size {table1.count} and {table2.count} | segments: {self.bisection_factor}, bisection threshold: {self.bisection_threshold}."
174-
)
181+
return self._diff_tables(table1, table2)
175182

176-
if table1.checksum == table2.checksum:
177-
return [] # No differences
183+
def _diff_tables(self, table1, table2, level=0, segment_index=None, segment_count=None):
184+
count1, count2 = precalc_attr("count", [table1, table2])
178185

179-
return self._diff_tables(table1, table2)
186+
if segment_index:
187+
logger.info(". " * level + f"Diffing segment {segment_index}/{segment_count} of size {count1} and {count2}")
188+
else:
189+
logger.info(
190+
f"Diffing tables of size {table1.count} and {table2.count} | segments: {self.bisection_factor}, bisection threshold: {self.bisection_threshold}."
191+
)
192+
193+
checksum1, checksum2 = precalc_attr("checksum", [table1, table2]) # Calculate checksum in parallel
180194

181-
def _diff_tables(self, table1, table2, level=0):
182-
count1 = table1.count
183-
count2 = table2.count
195+
if checksum1 == checksum2:
196+
return # No differences
184197

185198
# If count is below the threshold, just download and compare the columns locally
186199
# This saves time, as bisection speed is limited by ping and query performance.
@@ -204,14 +217,18 @@ def _diff_tables(self, table1, table2, level=0):
204217
# Create new instances of TableSegment between each checkpoint
205218
segmented1 = table1.segment_by_checkpoints(mutual_checkpoints)
206219
segmented2 = table2.segment_by_checkpoints(mutual_checkpoints)
220+
207221
if self.debug:
208222
logger.debug("Performing sanity tests for chosen segments (assert sum of fragments == whole)")
223+
precalc_attr("count", segmented1 + segmented2)
209224
assert count1 == sum(s.count for s in segmented1)
210225
assert count2 == sum(s.count for s in segmented2)
211226

212-
# Compare each pair of corresponding segments between table1 and table2
213-
for i, (t1, t2) in enumerate(safezip(segmented1, segmented2)):
214-
logger.info(". " * level + f"Diffing segment {i+1}/{len(segmented1)} of size {t1.count} and {t2.count}")
215-
if t1.checksum != t2.checksum:
216-
# Apply recursively
217-
yield from self._diff_tables(t1, t2, level + 1)
227+
# Recursively compare each pair of corresponding segments between table1 and table2
228+
diff_iters = [
229+
self._diff_tables(t1, t2, level + 1, i + 1, len(segmented1))
230+
for i, (t1, t2) in enumerate(safezip(segmented1, segmented2))
231+
]
232+
233+
for res in g_task_pool.map(list, diff_iters):
234+
yield from res

0 commit comments

Comments
 (0)