Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

performance: divide raw keyspace into segments, avoid full index scans #32

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
235 changes: 165 additions & 70 deletions data_diff/diff_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
"""

from collections import defaultdict
from typing import List, Tuple
import logging
from typing import List, Tuple
import datetime

from runtype import dataclass

Expand Down Expand Up @@ -50,7 +51,8 @@ def _make_update_range(self):
if self.max_time is not None:
yield Compare("<", self.update_column, Time(self.max_time))

def _make_select(self, *, table=None, columns=None, where=None, group_by=None, order_by=None):
def _make_select(self, *, table=None, columns=None, where=None,
group_by=None, order_by=None, where_or=None):
if columns is None:
columns = [self.key_column]
where = list(self._make_key_range()) + list(self._make_update_range()) + ([] if where is None else [where])
Expand All @@ -60,6 +62,7 @@ def _make_select(self, *, table=None, columns=None, where=None, group_by=None, o
where=where,
columns=columns,
group_by=group_by,
where_or=where_or,
order_by=order_by,
)

Expand All @@ -68,27 +71,47 @@ def get_values(self) -> list:
select = self._make_select(columns=self._relevant_columns)
return self.database.query(select, List[Tuple])

def choose_checkpoints(self, count: int) -> List[DbKey]:
"Suggests a bunch of evenly-spaced checkpoints to split by"
ratio = int(self.count / count)
assert ratio > 1
skip = f"mod(idx, {ratio}) = 0"
select = self._make_select(table=Enum(self.table_path, order_by=self.key_column), where=skip)
return self.database.query(select, List[int])
def choose_checkpoints(self, bisection_factor: int) -> List[DbKey]:
"""
Choose `bisection_factor - 1` (because of start and end) checkpoints in
the keyspace.

def find_checkpoints(self, checkpoints: List[DbKey]) -> List[DbKey]:
"Takes a list of potential checkpoints and returns those that exist"
where = In(self.key_column, checkpoints)
return self.database.query(self._make_select(where=where), List[int])
For example, a table of 1000 with `bisection_factor` of 4 would yield
the following checkpoints:
[250, 500, 750]

def segment_by_checkpoints(self, checkpoints: List[DbKey]) -> List["TableSegment"]:
"Split the current TableSegment to a bunch of smaller ones, separate by the given checkpoints"
Which would yield the following segments:
[1..249, 250..499, 500..749, 750..1000]
"""

assert self.end_key is not None
assert self.start_key is not None
assert bisection_factor >= 2
# 1..11 for bisection_factor 2 would mean gap=round(10/2)=5
# which means checkpoints returns only 1 value:
# [1 + 5 - 1] => [5]
# Then `segment_by_checkpoints` will produce:
# [1..5, 5..11]
gap = round((self.end_key - self.start_key) / (bisection_factor))
assert gap >= 1

proposed_checkpoints = [self.start_key + gap - 1]
# -2 because we add start + end in `segment_by_checkpoints`!
for i in range(bisection_factor - 2):
proposed_checkpoints.append(proposed_checkpoints[i] + gap - 1)

if self.start_key and self.end_key:
assert all(self.start_key <= c < self.end_key for c in checkpoints)
checkpoints.sort()
return proposed_checkpoints

# Calculate sub-segments
def segment_by_checkpoints(self, checkpoints: List[DbKey]) -> List["TableSegment"]:
"Split the current TableSegment to a bunch of smaller ones, separate by the given checkpoints"
# Make sure start_key and end_key are set, for the beginning and end
# they may not be.
assert self.start_key is not None
assert self.end_key is not None
assert all(self.start_key <= c < self.end_key for c in checkpoints)

# Calculate sub-segments, turns checkpoints such as [250, 500, 750] into
# [1..249, 250..499, 500..749, 750..1000].
positions = [self.start_key] + checkpoints + [self.end_key]
ranges = list(zip(positions[:-1], positions[1:]))

Expand All @@ -97,37 +120,84 @@ def segment_by_checkpoints(self, checkpoints: List[DbKey]) -> List["TableSegment

return tables

## Calculate checksums in one go, to prevent repetitive individual calls
# selects = [t._make_select(columns=[Checksum(self._relevant_columns)]) for t in tables]
# res = self.database.query(Select(columns=selects), list)
# checksums ,= res
# assert len(checksums) == len(checkpoints) + 1
# return [t.new(_checksum=checksum) for t, checksum in safezip(tables, checksums)]

def new(self, _count=None, _checksum=None, **kwargs) -> "TableSegment":
    """Using new() creates a copy of the instance using 'replace()', and makes sure the cache is reset"""
    # NOTE(review): the _count/_checksum parameters are accepted for
    # interface compatibility but intentionally ignored -- the copy always
    # starts with cleared caches so they are recomputed for the new range.
    return self.replace(_count=None, _checksum=None, **kwargs)

def __repr__(self):
    """Short identifier for logging: '<DatabaseClass>/<table, path>'."""
    db_name = type(self.database).__name__
    table_name = ", ".join(self.table_path)
    return f"{db_name}/{table_name}"

def query_start_key_and_end_key(self) -> Tuple[int, int]:
    """Query the database for the minimum and maximum key. This is used for
    setting the boundaries of the initial, full table segment.

    Returns a ``(start_key, end_key)`` pair; ``end_key`` is exclusive
    (one past the maximum key actually present).
    """
    select = self._make_select(columns=[f"min({self.key_column})", f"max({self.key_column})"])
    res = self.database.query(select, Tuple)[0]

    # Compare against None explicitly (not truthiness) so a legitimate
    # minimum key of 0 is not silently replaced with 1.
    start_key = res[0] if res[0] is not None else 1
    # TableSegments are always exclusive the last key:
    # (1..250) => # WHERE i >= 1 AND i < 250
    # Thus, for the very last segment (which is the one where these
    # aren't automatically set!) -- we have to add 1.
    end_key = res[1] + 1 if res[1] is not None else 1

    return (start_key, end_key)

def compute_checksum_and_count(self):
    """
    Query the database for the checksum and count for this segment. Note
    that it will _not_ include the `end_key` in this segment, as that's the
    beginning of the next segment.

    Raises:
        ValueError: if `start_key`/`end_key` have not been initialized
            (call `set_initial_start_key_and_end_key` first).
    """
    if self.start_key is None or self.end_key is None:
        raise ValueError("""
            `start_key` and/or `end_key` are not set. Likely this is because
            you didn't call `set_initial_start_key_and_end_key` to get the
            min(key) and max(key) from the database for the initial, whole
            table segment.
        """)
    if self._count is not None or self._checksum is not None:
        return  # already computed

    # Get the count in the same index pass. Much cheaper than doing it
    # separately.
    select = self._make_select(columns=[Count(), Checksum(self._relevant_columns)])
    result = self.database.query(select, Tuple)[0]
    # Fix: the count must not be gated on the *checksum* being truthy.
    # A checksum that is legitimately 0 (with a non-zero row count) would
    # previously have zeroed the count as well.
    self._count = result[0] or 0
    self._checksum = int(result[1]) if result[1] is not None else 0


@property
def count(self) -> int:
if self._count is None:
self._count = self.database.query(self._make_select(columns=[Count()]), int)
raise ValueError("""
You must call compute_checksum_and_count() before
accessing the count to ensure only one index scan
is performed.
""")

return self._count

@property
def _relevant_columns(self) -> List[str]:
return (
# The user may duplicate columns across -k, -t, -c so we de-dup here.
relevant = list(set(
[self.key_column]
+ ([self.update_column] if self.update_column is not None else [])
+ list(self.extra_columns)
)
))
relevant.sort()
return relevant

@property
def checksum(self) -> int:
if self._checksum is None:
self._checksum = (
self.database.query(self._make_select(columns=[Checksum(self._relevant_columns)]), int) or 0
)
# Get the count in the same index pass. Much cheaper than doing it
# separately.
select = self._make_select(columns=[Count(), Checksum(self._relevant_columns)])
result = self.database.query(select, Tuple)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably needs some error handling

self._checksum = int(result[0][1])
self._count = result[0][0]

return self._checksum


Expand All @@ -148,7 +218,6 @@ def diff_sets(a: set, b: set) -> iter:

DiffResult = iter # Iterator[Tuple[Literal["+", "-"], tuple]]


@dataclass
class TableDiffer:
"""Finds the diff between two SQL tables
Expand All @@ -160,7 +229,7 @@ class TableDiffer:
"""

bisection_factor: int = 32 # Into how many segments to bisect per iteration
bisection_threshold: int = 1024**2 # When should we stop bisecting and compare locally (in row count)
bisection_threshold: int = 10000 # When should we stop bisecting and compare locally (in row count)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think 1M is way too high when we are running this against databases far away from the machine. 10k seems more sensible. We can run some benchmarks in the future, but I think this is a safer default

debug: bool = False

def diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult:
Expand All @@ -176,49 +245,75 @@ def diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult:
if self.bisection_factor < 2:
raise ValueError("Must have at least two segments per iteration")

self.set_initial_start_key_and_end_key(table1, table2)

logger.info(
f"Diffing tables of size {table1.count} and {table2.count} | segments: {self.bisection_factor}, bisection threshold: {self.bisection_threshold}."
f"Diffing tables {repr(table1)} and {repr(table2)} | keys: {table1.start_key}...{table2.end_key} | segments: {self.bisection_factor}, bisection threshold: {self.bisection_threshold}."
)

if table1.checksum == table2.checksum:
return [] # No differences

return self._diff_tables(table1, table2)

def _diff_tables(self, table1, table2, level=0):
count1 = table1.count
count2 = table2.count

# If count is below the threshold, just download and compare the columns locally
# This saves time, as bisection speed is limited by ping and query performance.
if count1 < self.bisection_threshold and count2 < self.bisection_threshold:
rows1 = table1.get_values()
rows2 = table2.get_values()
diff = list(diff_sets(rows1, rows2))
logger.info(". " * level + f"Diff found {len(diff)} different rows.")
yield from diff
return

# Find mutual checkpoints between the two tables
checkpoints = table1.choose_checkpoints(self.bisection_factor - 1)
return self._diff_tables(table1, table2, self.bisection_factor)

def set_initial_start_key_and_end_key(self, table1: TableSegment, table2: TableSegment):
    """For the initial, full table segment we need to set the boundaries of
    the minimum and maximum key."""

    start1, end1 = table1.query_start_key_and_end_key()
    start2, end2 = table2.query_start_key_and_end_key()

    # Use the widest range covering both tables so their segments line up.
    # The end keys are already exclusive (the +1 is applied upstream, in
    # query_start_key_and_end_key), so every row is encapsulated.
    shared_start = min(start1, start2)
    shared_end = max(end1, end2)

    table1.start_key = shared_start
    table2.start_key = shared_start
    table1.end_key = shared_end
    table2.end_key = shared_end

    assert table1.start_key <= table1.end_key

def _diff_tables(self, table1, table2, bisection_factor, level=0):
if level > 50:
raise Exception("Recursing too deep; likely bug for infinite recursion")

# This is the upper bound, but it might be smaller if there are gaps.
# E.g. between id 1..10, id 5 might have been hard deleted.
keyspace_size = table1.end_key - table1.start_key

# We only check beyond level > 0, because otherwise we might scan the
# entire index with COUNT(*). For large tables with billions of rows, we
# need to split the COUNT(*) by the `bisection_factor`.
if level > 0 or keyspace_size < self.bisection_threshold:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should probably get cleaned up into its own function now

# In case the first segment is below the threshold
# This is where we get the count
table1.compute_checksum_and_count()
table2.compute_checksum_and_count()
count1 = table1.count # These have been precomputed with the checksum
count2 = table2.count

# If count is below the threshold, just download and compare the columns locally
# This saves time, as bisection speed is limited by ping and query performance.
if count1 < self.bisection_threshold and count2 < self.bisection_threshold:
rows1 = table1.get_values()
rows2 = table2.get_values()
diff = list(diff_sets(rows1, rows2))
logger.info(". " * level + f"Diff found {len(diff)} different rows.")
yield from diff
return

# Find checkpoints between the two tables, e.g. [250, 500, 750] for a
# table with 1000 ids and a bisection factor of 4.
checkpoints = table1.choose_checkpoints(bisection_factor)
assert checkpoints
mutual_checkpoints = table2.find_checkpoints([Value(c) for c in checkpoints])
mutual_checkpoints = list(set(mutual_checkpoints)) # Duplicate values are a problem!
logger.debug(". " * level + f"Found {len(mutual_checkpoints)} mutual checkpoints (out of {len(checkpoints)}).")
if not mutual_checkpoints:
raise Exception("Tables are too different.")

# Create new instances of TableSegment between each checkpoint
segmented1 = table1.segment_by_checkpoints(mutual_checkpoints)
segmented2 = table2.segment_by_checkpoints(mutual_checkpoints)
if self.debug:
logger.debug("Performing sanity tests for chosen segments (assert sum of fragments == whole)")
assert count1 == sum(s.count for s in segmented1)
assert count2 == sum(s.count for s in segmented2)
# [1..249, 250..499, 500..749, 750..1000]
segmented1 = table1.segment_by_checkpoints(checkpoints)
segmented2 = table2.segment_by_checkpoints(checkpoints)

# Compare each pair of corresponding segments between table1 and table2
for i, (t1, t2) in enumerate(safezip(segmented1, segmented2)):
logger.info(". " * level + f"Diffing segment {i+1}/{len(segmented1)} of size {t1.count} and {t2.count}")
n_keys = t1.end_key - t1.start_key
logger.info(". " * level + f"Diffing segment {i+1}/{len(segmented1)} keys={t1.start_key}..{t1.end_key-1} n_keys={n_keys}")
t1.compute_checksum_and_count()
t2.compute_checksum_and_count()

if t1.checksum != t2.checksum:
# Apply recursively
yield from self._diff_tables(t1, t2, level + 1)
yield from self._diff_tables(t1, t2, bisection_factor, level + 1)
8 changes: 7 additions & 1 deletion data_diff/sql.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Provides classes for a pseudo-SQL AST that compiles to SQL code
"""

from typing import List, Union, Tuple, Optional
from datetime import datetime

Expand Down Expand Up @@ -66,6 +65,7 @@ class Select(Sql):
columns: List[SqlOrStr]
table: SqlOrStr = None
where: List[SqlOrStr] = None
where_or: List[SqlOrStr] = None
order_by: List[SqlOrStr] = None
group_by: List[SqlOrStr] = None

Expand All @@ -80,6 +80,12 @@ def compile(self, parent_c: Compiler):
if self.where:
select += " WHERE " + " AND ".join(map(c.compile, self.where))

if self.where_or:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah yeah this is a relic left over from the first iteration where I found checkpoints still with this method, will remove

if self.where:
select += " AND " # The OR conditions after AND
else:
select += " WHERE "

if self.group_by:
select += " GROUP BY " + ", ".join(map(c.compile, self.group_by))

Expand Down
Loading