Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 6c30cdf

Browse files
authored
Merge pull request #431 from dlawin/issue_427
support combo pks in --dbt local_diff
2 parents e5f9121 + f569182 commit 6c30cdf

File tree

2 files changed

+14
-25
lines changed

2 files changed

+14
-25
lines changed

data_diff/dbt.py

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -73,26 +73,16 @@ def dbt_diff(
7373

7474
if is_cloud and len(diff_vars.primary_keys) > 0:
7575
_cloud_diff(diff_vars)
76-
elif is_cloud:
77-
rich.print(
78-
"[red]"
79-
+ ".".join(diff_vars.prod_path)
80-
+ " <> "
81-
+ ".".join(diff_vars.dev_path)
82-
+ "[/] \n"
83-
+ "Skipped due to missing primary-key tag\n"
84-
)
85-
86-
if not is_cloud and len(diff_vars.primary_keys) == 1:
76+
elif not is_cloud and len(diff_vars.primary_keys) > 0:
8777
_local_diff(diff_vars)
88-
elif not is_cloud:
78+
else:
8979
rich.print(
9080
"[red]"
9181
+ ".".join(diff_vars.prod_path)
9282
+ " <> "
9383
+ ".".join(diff_vars.dev_path)
9484
+ "[/] \n"
95-
+ "Skipped due to missing primary-key tag or multi-column primary-key (unsupported for non --cloud diffs)\n"
85+
+ "Skipped due to missing primary-key tag(s).\n"
9686
)
9787

9888
rich.print("Diffs Complete!")
@@ -127,10 +117,9 @@ def _local_diff(diff_vars: DiffVars) -> None:
127117
column_diffs_str = ""
128118
dev_qualified_string = ".".join(diff_vars.dev_path)
129119
prod_qualified_string = ".".join(diff_vars.prod_path)
130-
primary_key = diff_vars.primary_keys[0]
131120

132-
table1 = connect_to_table(diff_vars.connection, dev_qualified_string, primary_key)
133-
table2 = connect_to_table(diff_vars.connection, prod_qualified_string, primary_key)
121+
table1 = connect_to_table(diff_vars.connection, dev_qualified_string, tuple(diff_vars.primary_keys))
122+
table2 = connect_to_table(diff_vars.connection, prod_qualified_string, tuple(diff_vars.primary_keys))
134123

135124
table1_columns = list(table1.get_schema())
136125
try:
@@ -159,7 +148,7 @@ def _local_diff(diff_vars: DiffVars) -> None:
159148
if table2_set_diff:
160149
column_diffs_str += "Column(s) removed: " + str(table2_set_diff) + "\n"
161150

162-
mutual_set.discard(primary_key)
151+
mutual_set = mutual_set - set(diff_vars.primary_keys)
163152
extra_columns = tuple(mutual_set)
164153

165154
diff = diff_tables(table1, table2, threaded=True, algorithm=Algorithm.JOINDIFF, extra_columns=extra_columns)

tests/test_dbt.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -358,17 +358,17 @@ def test_local_diff(self, mock_diff_tables):
358358
mock_diff.__iter__.return_value = [1, 2, 3]
359359
dev_qualified_list = ["dev_db", "dev_schema", "dev_table"]
360360
prod_qualified_list = ["prod_db", "prod_schema", "prod_table"]
361-
expected_key = "key"
362-
diff_vars = DiffVars(dev_qualified_list, prod_qualified_list, [expected_key], None, mock_connection)
361+
expected_keys = ["key"]
362+
diff_vars = DiffVars(dev_qualified_list, prod_qualified_list, expected_keys, None, mock_connection)
363363
with patch("data_diff.dbt.connect_to_table", side_effect=[mock_table1, mock_table2]) as mock_connect:
364364
_local_diff(diff_vars)
365365

366366
mock_diff_tables.assert_called_once_with(
367367
mock_table1, mock_table2, threaded=True, algorithm=Algorithm.JOINDIFF, extra_columns=tuple(column_set)
368368
)
369369
self.assertEqual(mock_connect.call_count, 2)
370-
mock_connect.assert_any_call(mock_connection, ".".join(dev_qualified_list), expected_key)
371-
mock_connect.assert_any_call(mock_connection, ".".join(prod_qualified_list), expected_key)
370+
mock_connect.assert_any_call(mock_connection, ".".join(dev_qualified_list), tuple(expected_keys))
371+
mock_connect.assert_any_call(mock_connection, ".".join(prod_qualified_list), tuple(expected_keys))
372372
mock_diff.get_stats_string.assert_called_once()
373373

374374
@patch("data_diff.dbt.diff_tables")
@@ -384,17 +384,17 @@ def test_local_diff_no_diffs(self, mock_diff_tables):
384384
mock_diff.__iter__.return_value = []
385385
dev_qualified_list = ["dev_db", "dev_schema", "dev_table"]
386386
prod_qualified_list = ["prod_db", "prod_schema", "prod_table"]
387-
expected_key = "primary_key_column"
388-
diff_vars = DiffVars(dev_qualified_list, prod_qualified_list, [expected_key], None, mock_connection)
387+
expected_keys = ["primary_key_column"]
388+
diff_vars = DiffVars(dev_qualified_list, prod_qualified_list, expected_keys, None, mock_connection)
389389
with patch("data_diff.dbt.connect_to_table", side_effect=[mock_table1, mock_table2]) as mock_connect:
390390
_local_diff(diff_vars)
391391

392392
mock_diff_tables.assert_called_once_with(
393393
mock_table1, mock_table2, threaded=True, algorithm=Algorithm.JOINDIFF, extra_columns=tuple(column_set)
394394
)
395395
self.assertEqual(mock_connect.call_count, 2)
396-
mock_connect.assert_any_call(mock_connection, ".".join(dev_qualified_list), expected_key)
397-
mock_connect.assert_any_call(mock_connection, ".".join(prod_qualified_list), expected_key)
396+
mock_connect.assert_any_call(mock_connection, ".".join(dev_qualified_list), tuple(expected_keys))
397+
mock_connect.assert_any_call(mock_connection, ".".join(prod_qualified_list), tuple(expected_keys))
398398
mock_diff.get_stats_string.assert_not_called()
399399

400400
@patch("data_diff.dbt.rich.print")

0 commit comments

Comments
 (0)