This repository was archived by the owner on May 17, 2024. It is now read-only.

Ignore columns at runtime on request (e.g. with too many diffs in them) #822

Merged
merged 1 commit on Dec 29, 2023
34 changes: 32 additions & 2 deletions data_diff/diff_tables.py
@@ -1,12 +1,12 @@
"""Provides classes for performing a table diff
"""

import threading
import time
from abc import ABC, abstractmethod
from enum import Enum
from contextlib import contextmanager
from operator import methodcaller
from typing import Dict, Tuple, Iterator, Optional
from typing import Dict, Set, Tuple, Iterator, Optional
from concurrent.futures import ThreadPoolExecutor, as_completed

import attrs
@@ -184,6 +184,10 @@ class TableDiffer(ThreadBase, ABC):
bisection_factor = 32
stats: dict = {}

ignored_columns1: Set[str] = attrs.field(factory=set)
ignored_columns2: Set[str] = attrs.field(factory=set)
_ignored_columns_lock: threading.Lock = attrs.field(factory=threading.Lock, init=False)

def diff_tables(self, table1: TableSegment, table2: TableSegment, info_tree: InfoTree = None) -> DiffResultWrapper:
"""Diff the given tables.

@@ -353,6 +357,11 @@ def _bisect_and_diff_segments(
biggest_table = max(table1, table2, key=methodcaller("approximate_size"))
checkpoints = biggest_table.choose_checkpoints(self.bisection_factor - 1)

# Snapshot the ignored columns under the lock, to avoid segment misalignment when they change mid-diff.
with self._ignored_columns_lock:
table1 = attrs.evolve(table1, ignored_columns=frozenset(self.ignored_columns1))
table2 = attrs.evolve(table2, ignored_columns=frozenset(self.ignored_columns2))

# Create new instances of TableSegment between each checkpoint
segmented1 = table1.segment_by_checkpoints(checkpoints)
segmented2 = table2.segment_by_checkpoints(checkpoints)
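
The snapshot relies on attrs.evolve producing an independent copy. A standalone sketch (not part of this diff) of that copy-on-evolve behavior: segments created here keep the ignored-column set they were given, even if ignore_column() is called later.

import attrs

@attrs.define(frozen=True)
class Segment:
    ignored_columns: frozenset = frozenset()

s1 = Segment()
s2 = attrs.evolve(s1, ignored_columns=frozenset({"updated_at"}))
assert s1.ignored_columns == frozenset()   # the original is untouched
assert "updated_at" in s2.ignored_columns  # the snapshot carries the change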
@@ -363,3 +372,24 @@
ti.submit(
self._diff_segments, ti, t1, t2, info_node, max_rows, level + 1, i + 1, len(segmented1), priority=level
)

def ignore_column(self, column_name1: str, column_name2: str) -> None:
"""
Ignore the column (by name on sides A & B) in md5s & diffs from now on.

This affects two places:

- The columns are not checksummed for new(!) segments.
- The columns are ignored in in-memory diffing for segments that are already running.

The columns are never excluded from the fetched values, whether they
are the same or different, for data consistency.

Use this feature to collect a well-represented sample of differences
across all columns when one column differs heavily near the beginning
of a table (in segmentation/bisection order). Otherwise, that single
column could quickly hit the diff limit and stop the whole diff.
"""
with self._ignored_columns_lock:
self.ignored_columns1.add(column_name1)
self.ignored_columns2.add(column_name2)
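
A minimal usage sketch, not from this PR: the consumer of the diff stream decides when a column is too noisy. The handle() and too_noisy() callables below are hypothetical and caller-defined (e.g. counting per-column differences in the consumed rows). Since ignore_column() takes the lock, it is safe to call while the diff is still running.

from data_diff.hashdiff_tables import HashDiffer

differ = HashDiffer()

# table1/table2 are assumed to be existing TableSegment instances.
for sign, row in differ.diff_tables(table1, table2):
    handle(sign, row)              # hypothetical consumer of the diff stream
    if too_noisy("updated_at"):    # hypothetical caller-side heuristic
        # From now on, skip this column in checksums and in-memory diffs on both sides.
        differ.ignore_column("updated_at", "updated_at")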
56 changes: 42 additions & 14 deletions data_diff/hashdiff_tables.py
@@ -2,9 +2,10 @@
from numbers import Number
import logging
from collections import defaultdict
from typing import Iterator
from typing import Any, Collection, Dict, Iterator, List, Sequence, Set, Tuple

import attrs
from typing_extensions import Literal

from data_diff.abcs.database_types import ColType_UUID, NumericType, PrecisionType, StringType, Boolean, JSON
from data_diff.info_tree import InfoTree
@@ -20,25 +21,42 @@

logger = logging.getLogger("hashdiff_tables")


def diff_sets(a: list, b: list, json_cols: dict = None) -> Iterator:
sa = set(a)
sb = set(b)
# Local aliases, just for readability. TODO: switch to real type declarations later.
_Op = Literal["+", "-"]
_PK = Any
_Row = Tuple[Any, ...]


def diff_sets(
a: Sequence[_Row],
b: Sequence[_Row],
*,
json_cols: dict = None,
columns1: Sequence[str],
columns2: Sequence[str],
ignored_columns1: Collection[str],
ignored_columns2: Collection[str],
) -> Iterator:
# Compare only the columns of interest (PKs + relevant columns, minus the ignored ones). But yield the full rows, ignored columns included!
sa: Set[_Row] = {tuple(val for col, val in safezip(columns1, row) if col not in ignored_columns1) for row in a}
sb: Set[_Row] = {tuple(val for col, val in safezip(columns2, row) if col not in ignored_columns2) for row in b}

# The first item is always the key (see TableDiffer.relevant_columns)
# TODO update when we add compound keys to hashdiff
d = defaultdict(list)
diffs_by_pks: Dict[_PK, List[Tuple[_Op, _Row]]] = defaultdict(list)
for row in a:
if row not in sb:
d[row[0]].append(("-", row))
cutrow: _Row = tuple(val for col, val in zip(columns1, row) if col not in ignored_columns1)
if cutrow not in sb:
diffs_by_pks[row[0]].append(("-", row))
for row in b:
if row not in sa:
d[row[0]].append(("+", row))
cutrow: _Row = tuple(val for col, val in zip(columns2, row) if col not in ignored_columns2)
if cutrow not in sa:
diffs_by_pks[row[0]].append(("+", row))

warned_diff_cols = set()
for _k, v in sorted(d.items(), key=lambda i: i[0]):
for diffs in (diffs_by_pks[pk] for pk in sorted(diffs_by_pks)):
if json_cols:
parsed_match, overriden_diff_cols = diffs_are_equiv_jsons(v, json_cols)
parsed_match, overriden_diff_cols = diffs_are_equiv_jsons(diffs, json_cols)
if parsed_match:
to_warn = overriden_diff_cols - warned_diff_cols
for w in to_warn:
@@ -48,7 +66,7 @@ def diff_sets(a: list, b: list, json_cols: dict = None) -> Iterator:
)
warned_diff_cols.add(w)
continue
yield from v
yield from diffs
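
A worked example of the new behavior, using the signature above: the pk=1 rows differ only in the ignored column, so they produce no diff, while the pk=2 rows differ in "name" and are yielded with every fetched value intact.

rows1 = [(1, "alice", "2024-01-01"), (2, "bob", "2024-01-02")]
rows2 = [(1, "alice", "2024-06-30"), (2, "bobby", "2024-01-02")]

diffs = list(diff_sets(
    rows1, rows2,
    json_cols=None,
    columns1=["id", "name", "updated_at"],
    columns2=["id", "name", "updated_at"],
    ignored_columns1={"updated_at"},
    ignored_columns2={"updated_at"},
))
# Rows are compared with "updated_at" stripped, but yielded whole:
# [("-", (2, "bob", "2024-01-02")), ("+", (2, "bobby", "2024-01-02"))]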


@attrs.define(frozen=False)
@@ -201,7 +219,17 @@ def _bisect_and_diff_segments(
for i, colname in enumerate(table1.extra_columns)
if isinstance(table1._schema[colname], JSON)
}
diff = list(diff_sets(rows1, rows2, json_cols))
diff = list(
diff_sets(
rows1,
rows2,
json_cols=json_cols,
columns1=table1.relevant_columns,
columns2=table2.relevant_columns,
ignored_columns1=self.ignored_columns1,
ignored_columns2=self.ignored_columns2,
)
)

info_tree.info.set_diff(diff)
info_tree.info.rowcounts = {1: len(rows1), 2: len(rows2)}
18 changes: 11 additions & 7 deletions data_diff/table_segment.py
@@ -1,5 +1,5 @@
import time
from typing import List, Optional, Tuple
from typing import Container, List, Optional, Tuple
import logging
from itertools import product

@@ -114,6 +114,7 @@ class TableSegment:
key_columns: Tuple[str, ...]
update_column: Optional[str] = None
extra_columns: Tuple[str, ...] = ()
ignored_columns: Container[str] = frozenset()

# Restrict the segment
min_key: Optional[Vector] = None
@@ -179,7 +180,10 @@ def make_select(self):

def get_values(self) -> list:
"Download all the relevant values of the segment from the database"
select = self.make_select().select(*self._relevant_columns_repr)

# Fetch all the original columns, even if some were later excluded from checking.
fetched_cols = [NormalizeAsString(this[c]) for c in self.relevant_columns]
select = self.make_select().select(*fetched_cols)
return self.database.query(select, List[Tuple])

def choose_checkpoints(self, count: int) -> List[List[DbKey]]:
@@ -221,18 +225,18 @@ def relevant_columns(self) -> List[str]:

return list(self.key_columns) + extras

@property
def _relevant_columns_repr(self) -> List[Expr]:
return [NormalizeAsString(this[c]) for c in self.relevant_columns]

def count(self) -> int:
"""Count how many rows are in the segment, in one pass."""
return self.database.query(self.make_select().select(Count()), int)

def count_and_checksum(self) -> Tuple[int, int]:
"""Count and checksum the rows in the segment, in one pass."""

checked_columns = [c for c in self.relevant_columns if c not in self.ignored_columns]
cols = [NormalizeAsString(this[c]) for c in checked_columns]

start = time.monotonic()
q = self.make_select().select(Count(), Checksum(self._relevant_columns_repr))
q = self.make_select().select(Count(), Checksum(cols))
count, checksum = self.database.query(q, tuple)
duration = time.monotonic() - start
if duration > RECOMMENDED_CHECKSUM_DURATION:
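
A minimal sketch of the resulting split, assuming segment is an existing TableSegment (e.g. built via data_diff.connect_to_table): ignored columns are dropped from the checksum, while get_values() still fetches them.

import attrs

trimmed = attrs.evolve(segment, ignored_columns=frozenset({"updated_at"}))

count, checksum = trimmed.count_and_checksum()  # checksum skips "updated_at"
rows = trimmed.get_values()                     # fetched rows still include "updated_at"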