Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Fixed support for diffing columns of different names #230

Merged
merged 2 commits into from
Sep 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions data_diff/databases/clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def cursor(self, cursor_factory=None):
def _parse_type_repr(self, type_repr: str) -> Optional[Type[ColType]]:
nullable_prefix = "Nullable("
if type_repr.startswith(nullable_prefix):
type_repr = type_repr[len(nullable_prefix):].rstrip(")")
type_repr = type_repr[len(nullable_prefix) :].rstrip(")")

if type_repr.startswith("Decimal"):
type_repr = "Decimal"
Expand All @@ -91,7 +91,7 @@ def to_string(self, s: str) -> str:
return f"toString({s})"

def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
prec= coltype.precision
prec = coltype.precision
if coltype.rounds:
timestamp = f"toDateTime64(round(toUnixTimestamp64Micro(toDateTime64({value}, 6)) / 1000000, {prec}), 6)"
return self.to_string(timestamp)
Expand Down
28 changes: 14 additions & 14 deletions data_diff/diff_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,42 +172,42 @@ def _parse_key_range_result(self, key_type, key_range):
raise type(e)(f"Cannot apply {key_type} to {mn}, {mx}.") from e

def _validate_and_adjust_columns(self, table1, table2):
for c in table1._relevant_columns:
if c not in table1._schema:
for c1, c2 in safezip(table1._relevant_columns, table2._relevant_columns):
if c1 not in table1._schema:
raise ValueError(f"Column '{c}' not found in schema for table {table1}")
if c not in table2._schema:
if c2 not in table2._schema:
raise ValueError(f"Column '{c}' not found in schema for table {table2}")

# Update schemas to minimal mutual precision
col1 = table1._schema[c]
col2 = table2._schema[c]
col1 = table1._schema[c1]
col2 = table2._schema[c2]
if isinstance(col1, PrecisionType):
if not isinstance(col2, PrecisionType):
raise TypeError(f"Incompatible types for column '{c}': {col1} <-> {col2}")
raise TypeError(f"Incompatible types for column '{c1}': {col1} <-> {col2}")

lowest = min(col1, col2, key=attrgetter("precision"))

if col1.precision != col2.precision:
logger.warning(f"Using reduced precision {lowest} for column '{c}'. Types={col1}, {col2}")
logger.warning(f"Using reduced precision {lowest} for column '{c1}'. Types={col1}, {col2}")

table1._schema[c] = col1.replace(precision=lowest.precision, rounds=lowest.rounds)
table2._schema[c] = col2.replace(precision=lowest.precision, rounds=lowest.rounds)
table1._schema[c1] = col1.replace(precision=lowest.precision, rounds=lowest.rounds)
table2._schema[c2] = col2.replace(precision=lowest.precision, rounds=lowest.rounds)

elif isinstance(col1, NumericType):
if not isinstance(col2, NumericType):
raise TypeError(f"Incompatible types for column '{c}': {col1} <-> {col2}")
raise TypeError(f"Incompatible types for column '{c1}': {col1} <-> {col2}")

lowest = min(col1, col2, key=attrgetter("precision"))

if col1.precision != col2.precision:
logger.warning(f"Using reduced precision {lowest} for column '{c}'. Types={col1}, {col2}")
logger.warning(f"Using reduced precision {lowest} for column '{c1}'. Types={col1}, {col2}")

table1._schema[c] = col1.replace(precision=lowest.precision)
table2._schema[c] = col2.replace(precision=lowest.precision)
table1._schema[c1] = col1.replace(precision=lowest.precision)
table2._schema[c2] = col2.replace(precision=lowest.precision)

elif isinstance(col1, StringType):
if not isinstance(col2, StringType):
raise TypeError(f"Incompatible types for column '{c}': {col1} <-> {col2}")
raise TypeError(f"Incompatible types for column '{c1}': {col1} <-> {col2}")

for t in [table1, table2]:
for c in t._relevant_columns:
Expand Down
2 changes: 1 addition & 1 deletion tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
TEST_ORACLE_CONN_STRING: str = None
TEST_DATABRICKS_CONN_STRING: str = os.environ.get("DATADIFF_DATABRICKS_URI")
TEST_TRINO_CONN_STRING: str = os.environ.get("DATADIFF_TRINO_URI") or None
# clickhouse uri for provided docker - "clickhouse://clickhouse:Password1@localhost:9000/clickhouse"
# clickhouse uri for provided docker - "clickhouse://clickhouse:Password1@localhost:9000/clickhouse"
TEST_CLICKHOUSE_CONN_STRING: str = os.environ.get("DATADIFF_CLICKHOUSE_URI") or None

DEFAULT_N_SAMPLES = 50
Expand Down
10 changes: 5 additions & 5 deletions tests/test_database_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,8 +271,8 @@ def init_conns():
],
"uuid": [
"String",
]
}
],
},
}


Expand Down Expand Up @@ -482,13 +482,13 @@ def _insert_to_table(conn, table, values, type):
if type.startswith("DateTime64"):
value = f"'{sample.replace(tzinfo=None)}'"

elif type == 'DateTime':
elif type == "DateTime":
sample = sample.replace(tzinfo=None)
# Clickhouse's DateTime does not allow to store micro/milli/nano seconds
value = f"'{str(sample)[:19]}'"

elif type.startswith('Decimal'):
precision = int(type[8:].rstrip(')').split(',')[1])
elif type.startswith("Decimal"):
precision = int(type[8:].rstrip(")").split(",")[1])
value = round(sample, precision)

else:
Expand Down
74 changes: 55 additions & 19 deletions tests/test_diff_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,25 +234,7 @@ def setUp(self):
f"create table {self.table_dst}(id int, userid int, movieid int, rating {float_type}, timestamp timestamp)",
None,
)
# self.preql(
# f"""
# table {self.table_src_name} {{
# userid: int
# movieid: int
# rating: float
# timestamp: timestamp
# }}

# table {self.table_dst_name} {{
# userid: int
# movieid: int
# rating: float
# timestamp: timestamp
# }}
# commit()
# """
# )
self.preql.commit()
_commit(self.connection)

self.table = TableSegment(self.connection, self.table_src_path, "id", "timestamp", case_sensitive=False)
self.table2 = TableSegment(self.connection, self.table_dst_path, "id", "timestamp", case_sensitive=False)
Expand Down Expand Up @@ -402,6 +384,60 @@ def test_diff_sorted_by_key(self):
self.assertEqual(expected, diff)


@test_per_database
class TestDiffTables2(TestPerDatabase):
def test_diff_column_names(self):
float_type = _get_float_type(self.connection)

self.connection.query(
f"create table {self.table_src}(id int, rating {float_type}, timestamp timestamp)",
None,
)
self.connection.query(
f"create table {self.table_dst}(id2 int, rating2 {float_type}, timestamp2 timestamp)",
None,
)
_commit(self.connection)

time = "2022-01-01 00:00:00"
time2 = "2021-01-01 00:00:00"

time_str = f"timestamp '{time}'"
time_str2 = f"timestamp '{time2}'"
_insert_rows(
self.connection,
self.table_src,
["id", "rating", "timestamp"],
[
[1, 9, time_str],
[2, 9, time_str2],
[3, 9, time_str],
[4, 9, time_str2],
[5, 9, time_str],
],
)

_insert_rows(
self.connection,
self.table_dst,
["id2", "rating2", "timestamp2"],
[
[1, 9, time_str],
[2, 9, time_str2],
[3, 9, time_str],
[4, 9, time_str2],
[5, 9, time_str],
],
)

table1 = TableSegment(self.connection, self.table_src_path, "id", "timestamp", case_sensitive=False)
table2 = TableSegment(self.connection, self.table_dst_path, "id2", "timestamp2", case_sensitive=False)

differ = TableDiffer()
diff = list(differ.diff_tables(table1, table2))
assert diff == []


@test_per_database
class TestUUIDs(TestPerDatabase):
def setUp(self):
Expand Down