Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit bf9b83f

Browse files
authored
Merge pull request #643 from datafold/LAB-27_2
add --dbt support for --columns
2 parents 765cfaf + 7c72fa0 commit bf9b83f

File tree

3 files changed

+44
-5
lines changed

3 files changed

+44
-5
lines changed

data_diff/__main__.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -317,6 +317,7 @@ def main(conf, run, **kw):
317317
json_output=kw["json_output"],
318318
state=state,
319319
where_flag=kw["where"],
320+
columns_flag=kw["columns"],
320321
)
321322
else:
322323
return _data_diff(

data_diff/dbt.py

Lines changed: 7 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -71,6 +71,7 @@ def dbt_diff(
7171
state: Optional[str] = None,
7272
log_status_handler: Optional[LogStatusHandler] = None,
7373
where_flag: Optional[str] = None,
74+
columns_flag: Optional[Tuple[str]] = None,
7475
) -> None:
7576
print_version_info()
7677
diff_threads = []
@@ -110,7 +111,7 @@ def dbt_diff(
110111
if log_status_handler:
111112
log_status_handler.set_prefix(f"Diffing {model.alias} \n")
112113

113-
diff_vars = _get_diff_vars(dbt_parser, config, model, where_flag)
114+
diff_vars = _get_diff_vars(dbt_parser, config, model, where_flag, columns_flag)
114115

115116
# we won't always have a prod path when using state
116117
# when the model DNE in prod manifest, skip the model diff
@@ -160,7 +161,9 @@ def _get_diff_vars(
160161
config: TDatadiffConfig,
161162
model,
162163
where_flag: Optional[str] = None,
164+
columns_flag: Optional[Tuple[str]] = None,
163165
) -> TDiffVars:
166+
cli_columns = list(columns_flag) if columns_flag else []
164167
dev_database = model.database
165168
dev_schema = model.schema_
166169

@@ -189,10 +192,10 @@ def _get_diff_vars(
189192
primary_keys=primary_keys,
190193
connection=dbt_parser.connection,
191194
threads=dbt_parser.threads,
192-
# --where takes precedence over any model level config
195+
# cli flags take precedence over any model level config
193196
where_filter=where_flag or datadiff_model_config.where_filter,
194-
include_columns=datadiff_model_config.include_columns,
195-
exclude_columns=datadiff_model_config.exclude_columns,
197+
include_columns=cli_columns or datadiff_model_config.include_columns,
198+
exclude_columns=[] if cli_columns else datadiff_model_config.exclude_columns,
196199
)
197200

198201

tests/test_dbt.py

Lines changed: 36 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -30,8 +30,9 @@ class TestDbtDiffer(unittest.TestCase):
3030
def test_integration_basic_dbt(self):
3131
artifacts_path = os.getcwd() + "/tests/dbt_artifacts"
3232
test_project_path = os.environ.get("DATA_DIFF_DBT_PROJ") or artifacts_path
33+
test_profiles_path = os.environ.get("DATA_DIFF_DBT_PROFILES") or artifacts_path
3334
diff = run_datadiff_cli(
34-
"--dbt", "--dbt-project-dir", test_project_path, "--dbt-profiles-dir", test_project_path
35+
"--dbt", "--dbt-project-dir", test_project_path, "--dbt-profiles-dir", test_profiles_path
3536
)
3637

3738
# assertions for the diff that exists in tests/dbt_artifacts/jaffle_shop.duckdb
@@ -933,3 +934,37 @@ def test_get_diff_vars_call_get_prod_path_from_manifest(
933934
mock_prod_path_from_manifest.assert_called_once_with(mock_model, mock_dbt_parser.prod_manifest_obj)
934935
self.assertEqual(diff_vars.prod_path[0], mock_prod_path_from_manifest.return_value[0])
935936
self.assertEqual(diff_vars.prod_path[1], mock_prod_path_from_manifest.return_value[1])
937+
938+
@patch("data_diff.dbt._get_prod_path_from_config")
939+
@patch("data_diff.dbt._get_prod_path_from_manifest")
940+
def test_get_diff_vars_cli_columns(self, mock_prod_path_from_manifest, mock_prod_path_from_config):
941+
config = TDatadiffConfig(prod_database="prod_db")
942+
mock_model = Mock()
943+
primary_keys = ["a_primary_key"]
944+
mock_model.database = "a_dev_db"
945+
mock_model.schema_ = "a_schema"
946+
mock_model.config.schema_ = None
947+
mock_model.config.database = None
948+
mock_model.alias = "a_model_name"
949+
mock_model.unique_id = "unique_id"
950+
mock_tdatadiffmodelconfig = Mock()
951+
mock_tdatadiffmodelconfig.where_filter = "where"
952+
mock_tdatadiffmodelconfig.include_columns = ["include"]
953+
mock_tdatadiffmodelconfig.exclude_columns = ["exclude"]
954+
mock_dbt_parser = Mock()
955+
mock_dbt_parser.get_datadiff_model_config.return_value = mock_tdatadiffmodelconfig
956+
mock_dbt_parser.connection = {}
957+
mock_dbt_parser.threads = 0
958+
mock_dbt_parser.get_pk_from_model.return_value = primary_keys
959+
mock_dbt_parser.requires_upper = False
960+
mock_dbt_parser.prod_manifest_obj = None
961+
mock_prod_path_from_config.return_value = ("prod_db", "prod_schema")
962+
cli_columns = ("col1", "col2")
963+
964+
diff_vars = _get_diff_vars(mock_dbt_parser, config, mock_model, where_flag=None, columns_flag=cli_columns)
965+
966+
mock_dbt_parser.get_pk_from_model.assert_called_once()
967+
mock_prod_path_from_config.assert_called_once_with(config, mock_model, mock_model.database, mock_model.schema_)
968+
mock_prod_path_from_manifest.assert_not_called()
969+
self.assertEqual(diff_vars.include_columns, list(cli_columns))
970+
self.assertEqual(diff_vars.exclude_columns, [])

0 commit comments

Comments
 (0)