Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Make InfoTree classes overrideable #824

Merged
merged 2 commits into from
Dec 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions data_diff/diff_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from enum import Enum
from contextlib import contextmanager
from operator import methodcaller
- from typing import Dict, Set, List, Tuple, Iterator, Optional
+ from typing import Dict, Set, List, Tuple, Iterator, Optional, Union
from concurrent.futures import ThreadPoolExecutor, as_completed

import attrs
Expand Down Expand Up @@ -182,6 +182,8 @@ def get_stats_dict(self, is_dbt: bool = False):

@attrs.define(frozen=False)
class TableDiffer(ThreadBase, ABC):
+ INFO_TREE_CLASS = InfoTree
+
bisection_factor = 32
stats: dict = {}

Expand All @@ -204,7 +206,8 @@ def diff_tables(self, table1: TableSegment, table2: TableSegment, info_tree: Inf
Where `row` is a tuple of values, corresponding to the diffed columns.
"""
if info_tree is None:
- info_tree = InfoTree(SegmentInfo([table1, table2]))
+ segment_info = self.INFO_TREE_CLASS.SEGMENT_INFO_CLASS([table1, table2])
+ info_tree = self.INFO_TREE_CLASS(segment_info)
return DiffResultWrapper(self._diff_tables_wrapper(table1, table2, info_tree), info_tree, self.stats)

def _diff_tables_wrapper(self, table1: TableSegment, table2: TableSegment, info_tree: InfoTree) -> DiffResult:
Expand Down Expand Up @@ -259,7 +262,7 @@ def _validate_and_adjust_columns(self, table1: TableSegment, table2: TableSegmen

def _diff_tables_root(
self, table1: TableSegment, table2: TableSegment, info_tree: InfoTree
- ) -> DiffResult | DiffResultList:
+ ) -> Union[DiffResult, DiffResultList]:
return self._bisect_and_diff_tables(table1, table2, info_tree)

@abstractmethod
Expand Down
8 changes: 6 additions & 2 deletions data_diff/info_tree.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import List, Dict, Optional, Any, Tuple, Union

import attrs
+ from typing_extensions import Self

from data_diff.table_segment import TableSegment

Expand Down Expand Up @@ -41,11 +42,14 @@ def update_from_children(self, child_infos):

@attrs.define(frozen=True)
class InfoTree:
+ SEGMENT_INFO_CLASS = SegmentInfo
+
info: SegmentInfo
children: List["InfoTree"] = attrs.field(factory=list)

- def add_node(self, table1: TableSegment, table2: TableSegment, max_rows: int = None):
- node = InfoTree(SegmentInfo([table1, table2], max_rows=max_rows))
+ def add_node(self, table1: TableSegment, table2: TableSegment, max_rows: Optional[int] = None) -> Self:
+ cls = self.__class__
+ node = cls(cls.SEGMENT_INFO_CLASS([table1, table2], max_rows=max_rows))
self.children.append(node)
return node

Expand Down