performance: divide raw keyspace into segments, avoid full index scans #32
Changes from all commits
cd0e5ed
a9a30f1
eed5377
871949a
e148e24
186d490
@@ -2,8 +2,9 @@
"""

from collections import defaultdict
from typing import List, Tuple
import logging
from typing import List, Tuple
import datetime

from runtype import dataclass

@@ -50,7 +51,8 @@ def _make_update_range(self):
if self.max_time is not None:
yield Compare("<", self.update_column, Time(self.max_time))

def _make_select(self, *, table=None, columns=None, where=None, group_by=None, order_by=None):
def _make_select(self, *, table=None, columns=None, where=None,
group_by=None, order_by=None, where_or=None):
if columns is None:
columns = [self.key_column]
where = list(self._make_key_range()) + list(self._make_update_range()) + ([] if where is None else [where])

@@ -60,6 +62,7 @@ def _make_select(self, *, table=None, columns=None, where=None, group_by=None, o
where=where,
columns=columns,
group_by=group_by,
where_or=where_or,
order_by=order_by,
)

@@ -68,27 +71,47 @@ def get_values(self) -> list:
select = self._make_select(columns=self._relevant_columns)
return self.database.query(select, List[Tuple])

def choose_checkpoints(self, count: int) -> List[DbKey]:
"Suggests a bunch of evenly-spaced checkpoints to split by"
ratio = int(self.count / count)
assert ratio > 1
skip = f"mod(idx, {ratio}) = 0"
select = self._make_select(table=Enum(self.table_path, order_by=self.key_column), where=skip)
return self.database.query(select, List[int])
def choose_checkpoints(self, bisection_factor: int) -> List[DbKey]:
"""
Choose `bisection_factor - 1` (because of start and end) checkpoints in
the keyspace.

def find_checkpoints(self, checkpoints: List[DbKey]) -> List[DbKey]:
"Takes a list of potential checkpoints and returns those that exist"
where = In(self.key_column, checkpoints)
return self.database.query(self._make_select(where=where), List[int])
For example, a table of 1000 with `bisection_factor` of 4 would yield
the following checkpoints:
[250, 500, 750]

def segment_by_checkpoints(self, checkpoints: List[DbKey]) -> List["TableSegment"]:
"Split the current TableSegment into a bunch of smaller ones, separated by the given checkpoints"
Which would yield the following segments:
[1..249, 250..499, 500..749, 750..1000]
"""

assert self.end_key is not None
assert self.start_key is not None
assert bisection_factor >= 2
# 1..11 for bisection_factor 2 would mean gap=round(10/2)=5
# which means checkpoints returns only 1 value:
# [1 + 5 - 1] => [5]
# Then `segment_by_checkpoints` will produce:
# [1..5, 5..11]
gap = round((self.end_key - self.start_key) / (bisection_factor))
assert gap >= 1

proposed_checkpoints = [self.start_key + gap - 1]
# -2 because we add start + end in `segment_by_checkpoints`!
for i in range(bisection_factor - 2):
proposed_checkpoints.append(proposed_checkpoints[i] + gap - 1)

if self.start_key and self.end_key:
assert all(self.start_key <= c < self.end_key for c in checkpoints)
checkpoints.sort()
return proposed_checkpoints

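To make the checkpoint arithmetic above concrete, here is a standalone sketch of the same gap-based selection (a hypothetical helper for illustration, not part of the patch; the real method reads `start_key`/`end_key` from the TableSegment):

```python
def choose_checkpoints_sketch(start_key: int, end_key: int, bisection_factor: int) -> list:
    """Pick bisection_factor - 1 roughly evenly spaced keys inside [start_key, end_key)."""
    assert bisection_factor >= 2
    gap = round((end_key - start_key) / bisection_factor)
    assert gap >= 1
    checkpoints = [start_key + gap - 1]
    # bisection_factor - 2 more checkpoints; start and end themselves are
    # added later by segment_by_checkpoints.
    for i in range(bisection_factor - 2):
        checkpoints.append(checkpoints[i] + gap - 1)
    return checkpoints

# Keys 1..1001 (end exclusive) with bisection_factor=4: gap=250, so the result
# is [250, 499, 748], approximately the [250, 500, 750] from the docstring.
print(choose_checkpoints_sketch(1, 1001, 4))
```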
# Calculate sub-segments
def segment_by_checkpoints(self, checkpoints: List[DbKey]) -> List["TableSegment"]:
"Split the current TableSegment into a bunch of smaller ones, separated by the given checkpoints"
# Make sure start_key and end_key are set, for the beginning and end
# they may not be.
assert self.start_key is not None
assert self.end_key is not None
assert all(self.start_key <= c < self.end_key for c in checkpoints)

# Calculate sub-segments, turns checkpoints such as [250, 500, 750] into
# [1..249, 250..499, 500..749, 750..1000].
positions = [self.start_key] + checkpoints + [self.end_key]
ranges = list(zip(positions[:-1], positions[1:]))

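The zip step above is the whole segmentation trick; a minimal sketch of it in isolation (hypothetical function name, the real method builds TableSegment instances from these pairs):

```python
def segment_ranges_sketch(start_key: int, end_key: int, checkpoints: list) -> list:
    """Turn checkpoints into half-open (start, end) pairs covering [start_key, end_key)."""
    positions = [start_key] + list(checkpoints) + [end_key]
    # Consecutive pairs become segments: start inclusive, end exclusive.
    return list(zip(positions[:-1], positions[1:]))

# [(1, 250), (250, 500), (500, 750), (750, 1001)]
print(segment_ranges_sketch(1, 1001, [250, 500, 750]))
```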
@@ -97,37 +120,84 @@ def segment_by_checkpoints(self, checkpoints: List[DbKey]) -> List["TableSegment

return tables

## Calculate checksums in one go, to prevent repetitive individual calls
# selects = [t._make_select(columns=[Checksum(self._relevant_columns)]) for t in tables]
# res = self.database.query(Select(columns=selects), list)
# checksums ,= res
# assert len(checksums) == len(checkpoints) + 1
# return [t.new(_checksum=checksum) for t, checksum in safezip(tables, checksums)]

def new(self, _count=None, _checksum=None, **kwargs) -> "TableSegment":
"""Using new() creates a copy of the instance using 'replace()', and makes sure the cache is reset"""
return self.replace(_count=None, _checksum=None, **kwargs)

def __repr__(self):
return f"{type(self.database).__name__}/{', '.join(self.table_path)}"

def query_start_key_and_end_key(self) -> Tuple[int, int]:
"""Query database for minimum and maximum key. This is used for setting
the boundaries of the initial, full table segment."""
select = self._make_select(columns=[f"min({self.key_column})", f"max({self.key_column})"])

res = self.database.query(select, Tuple)[0]

start_key = res[0] or 1
# TableSegments are always exclusive of the last key:
# (1..250) => # WHERE i >= 1 AND i < 250
# Thus, for the very last segment (which is the one where these
# aren't automatically set!) -- we have to add 1.
end_key = res[1] + 1 if res[1] else 1

return (start_key, end_key)

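The exclusive end_key convention mentioned in the comment is what the +1 compensates for. As a hedged illustration of the kind of range predicate a segment implies (the real SQL is assembled by `_make_key_range`/`_make_select`; the column name below is made up):

```python
def key_range_where_sketch(key_column: str, start_key: int, end_key: int) -> str:
    # start_key is inclusive, end_key is exclusive, which is why the last
    # segment needs end_key = max(key) + 1 to cover the final row.
    return f"{key_column} >= {start_key} AND {key_column} < {end_key}"

print(key_range_where_sketch("id", 1, 250))  # id >= 1 AND id < 250
```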
def compute_checksum_and_count(self):
"""
Query the database for the checksum and count for this segment. Note
that it will _not_ include the `end_key` in this segment, as that's the
beginning of the next segment.
"""
if self.start_key is None or self.end_key is None:
raise ValueError("""
`start_key` and/or `end_key` are not set. Likely this is because
you didn't call `set_initial_start_key_and_end_key` to get the
min(key) and max(key) from the database for the initial, whole
table segment.
""")
if self._count is not None or self._checksum is not None:
return # already computed

# Get the count in the same index pass. Much cheaper than doing it
# separately.
select = self._make_select(columns=[Count(), Checksum(self._relevant_columns)])
result = self.database.query(select, Tuple)[0]
self._count = result[0] if result[1] else 0
self._checksum = int(result[1]) if result[1] else 0

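The win here is that count and checksum come back from a single index pass instead of two separate scans. A rough sketch of the query shape (illustrative only; the actual checksum expression is produced by the Checksum node and differs per database):

```python
def checksum_and_count_sql_sketch(table: str, key_column: str, start_key: int, end_key: int) -> str:
    # One SELECT returns both aggregates over the same key range.
    return (
        f"SELECT count(*), sum(<per-row checksum>) FROM {table} "
        f"WHERE {key_column} >= {start_key} AND {key_column} < {end_key}"
    )

# Typical usage after this change: compute once, then read the cached values.
# segment.compute_checksum_and_count()
# total, digest = segment.count, segment.checksum
```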
@property
def count(self) -> int:
if self._count is None:
self._count = self.database.query(self._make_select(columns=[Count()]), int)
raise ValueError("""
You must call compute_checksum_and_count() before
accessing the count to ensure only one index scan
is performed.
""")

return self._count

@property
def _relevant_columns(self) -> List[str]:
return (
# The user may duplicate columns across -k, -t, -c so we de-dup here.
relevant = list(set(
[self.key_column]
+ ([self.update_column] if self.update_column is not None else [])
+ list(self.extra_columns)
)
))
relevant.sort()
return relevant

@property
def checksum(self) -> int:
if self._checksum is None:
self._checksum = (
self.database.query(self._make_select(columns=[Checksum(self._relevant_columns)]), int) or 0
)
# Get the count in the same index pass. Much cheaper than doing it
# separately.
select = self._make_select(columns=[Count(), Checksum(self._relevant_columns)])
result = self.database.query(select, Tuple)

Review comment: Probably needs some error handling.

self._checksum = int(result[0][1])
self._count = result[0][0]

return self._checksum

@@ -148,7 +218,6 @@ def diff_sets(a: set, b: set) -> iter:

DiffResult = iter # Iterator[Tuple[Literal["+", "-"], tuple]]

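diff_sets itself is untouched by this PR and only its signature appears in the hunk header above; a plausible minimal implementation consistent with the DiffResult type would look like this (a sketch, not necessarily the project's exact code):

```python
def diff_sets_sketch(a: set, b: set):
    # Yield ('-', row) for rows only present in the first table and
    # ('+', row) for rows only present in the second.
    for row in a - b:
        yield "-", row
    for row in b - a:
        yield "+", row
```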
@dataclass
class TableDiffer:
"""Finds the diff between two SQL tables

@@ -160,7 +229,7 @@ class TableDiffer:
"""

bisection_factor: int = 32 # Into how many segments to bisect per iteration
bisection_threshold: int = 1024**2 # When should we stop bisecting and compare locally (in row count)
bisection_threshold: int = 10000 # When should we stop bisecting and compare locally (in row count)

Review comment: I think 1M is way too high when we are running this against databases far away from the machine. 10k seems more sensible. We can run some benchmarks in the future, but I think this is a safer default.

debug: bool = False

def diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult:

@@ -176,49 +245,75 @@ def diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult:
if self.bisection_factor < 2:
raise ValueError("Must have at least two segments per iteration")

self.set_initial_start_key_and_end_key(table1, table2)

logger.info(
f"Diffing tables of size {table1.count} and {table2.count} | segments: {self.bisection_factor}, bisection threshold: {self.bisection_threshold}."
f"Diffing tables {repr(table1)} and {repr(table2)} | keys: {table1.start_key}...{table2.end_key} | segments: {self.bisection_factor}, bisection threshold: {self.bisection_threshold}."
)

if table1.checksum == table2.checksum:
return [] # No differences

return self._diff_tables(table1, table2)

def _diff_tables(self, table1, table2, level=0):
count1 = table1.count
count2 = table2.count

# If count is below the threshold, just download and compare the columns locally
# This saves time, as bisection speed is limited by ping and query performance.
if count1 < self.bisection_threshold and count2 < self.bisection_threshold:
rows1 = table1.get_values()
rows2 = table2.get_values()
diff = list(diff_sets(rows1, rows2))
logger.info(". " * level + f"Diff found {len(diff)} different rows.")
yield from diff
return

# Find mutual checkpoints between the two tables
checkpoints = table1.choose_checkpoints(self.bisection_factor - 1)
return self._diff_tables(table1, table2, self.bisection_factor)

def set_initial_start_key_and_end_key(self, table1: TableSegment, table2: TableSegment):
"""For the initial, full table segment we need to set the boundaries of
the minimum and maximum key."""

table1_start_key, table1_end_key = table1.query_start_key_and_end_key()
table2_start_key, table2_end_key = table2.query_start_key_and_end_key()

table1.start_key = min(table1_start_key, table2_start_key)
table2.start_key = table1.start_key
# The +1 in the end key is to make sure that the last row is encapsulated in
# the final range. Because every range query assumes < end_key, not <=.
table1.end_key = max(table1_end_key, table2_end_key)
table2.end_key = table1.end_key

assert table1.start_key <= table1.end_key

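A minimal sketch of the boundary merge this method performs, written as a plain function (hypothetical names; the real code mutates both TableSegment instances in place):

```python
def merged_key_bounds_sketch(bounds1: tuple, bounds2: tuple) -> tuple:
    """Union of two (start_key, end_key) pairs so both tables are scanned over the same keyspace."""
    start = min(bounds1[0], bounds2[0])
    # end_key is already max(key) + 1, so the last row stays inside the half-open range.
    end = max(bounds1[1], bounds2[1])
    assert start <= end
    return start, end

# Tables spanning keys 1..900 and 5..1000 are diffed over 1..1001 (end exclusive).
print(merged_key_bounds_sketch((1, 901), (5, 1001)))  # (1, 1001)
```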
def _diff_tables(self, table1, table2, bisection_factor, level=0):
if level > 50:
raise Exception("Recursing too deep; likely bug for infinite recursion")

# This is the upper bound, but it might be smaller if there are gaps.
# E.g. between id 1..10, id 5 might have been hard deleted.
keyspace_size = table1.end_key - table1.start_key

# We only check beyond level > 0, because otherwise we might scan the
# entire index with COUNT(*). For large tables with billions of rows, we
# need to split the COUNT(*) by the `bisection_factor`.
if level > 0 or keyspace_size < self.bisection_threshold:

Review comment: This should probably get cleaned up into its own function now.

# In case the first segment is below the threshold
# This is where we get the count
table1.compute_checksum_and_count()
table2.compute_checksum_and_count()
count1 = table1.count # These have been precomputed with the checksum
count2 = table2.count

# If count is below the threshold, just download and compare the columns locally
# This saves time, as bisection speed is limited by ping and query performance.
if count1 < self.bisection_threshold and count2 < self.bisection_threshold:
rows1 = table1.get_values()
rows2 = table2.get_values()
diff = list(diff_sets(rows1, rows2))
logger.info(". " * level + f"Diff found {len(diff)} different rows.")
yield from diff
return

# Find checkpoints between the two tables, e.g. [250, 500, 750] for a
# table with 1000 ids and a bisection factor of 4.
checkpoints = table1.choose_checkpoints(bisection_factor)
assert checkpoints
mutual_checkpoints = table2.find_checkpoints([Value(c) for c in checkpoints])
mutual_checkpoints = list(set(mutual_checkpoints)) # Duplicate values are a problem!
logger.debug(". " * level + f"Found {len(mutual_checkpoints)} mutual checkpoints (out of {len(checkpoints)}).")
if not mutual_checkpoints:
raise Exception("Tables are too different.")

# Create new instances of TableSegment between each checkpoint
segmented1 = table1.segment_by_checkpoints(mutual_checkpoints)
segmented2 = table2.segment_by_checkpoints(mutual_checkpoints)
if self.debug:
logger.debug("Performing sanity tests for chosen segments (assert sum of fragments == whole)")
assert count1 == sum(s.count for s in segmented1)
assert count2 == sum(s.count for s in segmented2)
# [1..249, 250..499, 500..749, 750..1000]
segmented1 = table1.segment_by_checkpoints(checkpoints)
segmented2 = table2.segment_by_checkpoints(checkpoints)

# Compare each pair of corresponding segments between table1 and table2
for i, (t1, t2) in enumerate(safezip(segmented1, segmented2)):
logger.info(". " * level + f"Diffing segment {i+1}/{len(segmented1)} of size {t1.count} and {t2.count}")
n_keys = t1.end_key - t1.start_key
logger.info(". " * level + f"Diffing segment {i+1}/{len(segmented1)} keys={t1.start_key}..{t1.end_key-1} n_keys={n_keys}")
t1.compute_checksum_and_count()
t2.compute_checksum_and_count()

if t1.checksum != t2.checksum:
# Apply recursively
yield from self._diff_tables(t1, t2, level + 1)
yield from self._diff_tables(t1, t2, bisection_factor, level + 1)
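Taken together, the new control flow is roughly the following. This is a simplified sketch of the logic above, reusing the TableSegment methods from the diff; it is not the verbatim implementation:

```python
def diff_segments_sketch(t1, t2, bisection_factor, bisection_threshold, level=0):
    # Checksum/count only after the first split (or for small keyspaces), so we
    # never run COUNT(*) over the whole index of a huge table up front.
    keyspace_size = t1.end_key - t1.start_key
    if level > 0 or keyspace_size < bisection_threshold:
        t1.compute_checksum_and_count()
        t2.compute_checksum_and_count()
        # Small enough: download the rows and diff them locally.
        if t1.count < bisection_threshold and t2.count < bisection_threshold:
            yield from diff_sets(t1.get_values(), t2.get_values())
            return

    # Otherwise split the keyspace and recurse only into segment pairs whose
    # checksums disagree.
    checkpoints = t1.choose_checkpoints(bisection_factor)
    for s1, s2 in zip(t1.segment_by_checkpoints(checkpoints),
                      t2.segment_by_checkpoints(checkpoints)):
        s1.compute_checksum_and_count()
        s2.compute_checksum_and_count()
        if s1.checksum != s2.checksum:
            yield from diff_segments_sketch(s1, s2, bisection_factor,
                                            bisection_threshold, level + 1)
```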
@@ -1,6 +1,5 @@
"""Provides classes for a pseudo-SQL AST that compiles to SQL code
"""

from typing import List, Union, Tuple, Optional
from datetime import datetime

@@ -66,6 +65,7 @@ class Select(Sql):
columns: List[SqlOrStr]
table: SqlOrStr = None
where: List[SqlOrStr] = None
where_or: List[SqlOrStr] = None
order_by: List[SqlOrStr] = None
group_by: List[SqlOrStr] = None

@@ -80,6 +80,12 @@ def compile(self, parent_c: Compiler):
if self.where:
select += " WHERE " + " AND ".join(map(c.compile, self.where))

if self.where_or:

Review comment: What's this?
Reply: Ah yeah this is a relic left over from the first iteration where I found checkpoints still with this method, will remove.

if self.where:
select += " AND " # The OR conditions after AND
else:
select += " WHERE "

if self.group_by:
select += " GROUP BY " + ", ".join(map(c.compile, self.group_by))

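For context, a hedged sketch of what the where_or branch is presumably meant to emit: OR-joined conditions appended after the AND-joined where list. As the review notes, the parameter ended up unused in this PR and is slated for removal.

```python
def compile_where_sketch(where: list, where_or: list) -> str:
    # Sketch only; the real compile() works on SQL AST nodes, not raw strings.
    sql = ""
    if where:
        sql += " WHERE " + " AND ".join(where)
    if where_or:
        sql += " AND " if where else " WHERE "
        sql += "(" + " OR ".join(where_or) + ")"
    return sql

print(compile_where_sketch(["id >= 1", "id < 250"], ["id = 249", "id = 499"]))
#  WHERE id >= 1 AND id < 250 AND (id = 249 OR id = 499)
```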