Skip to content

Commit c57bfed

Browse files
committed
get prod alias from manifest when provided
1 parent 765cfaf commit c57bfed

File tree

2 files changed

+17
-11
lines changed

2 files changed

+17
-11
lines changed

data_diff/dbt.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -163,22 +163,22 @@ def _get_diff_vars(
163163
) -> TDiffVars:
164164
dev_database = model.database
165165
dev_schema = model.schema_
166-
166+
dev_alias = prod_alias = model.alias
167167
primary_keys = dbt_parser.get_pk_from_model(model, dbt_parser.unique_columns, "primary-key")
168168

169169
# prod path is constructed via configuration or the prod manifest via --state
170170
if dbt_parser.prod_manifest_obj:
171-
prod_database, prod_schema = _get_prod_path_from_manifest(model, dbt_parser.prod_manifest_obj)
171+
prod_database, prod_schema, prod_alias = _get_prod_path_from_manifest(model, dbt_parser.prod_manifest_obj)
172172
else:
173173
prod_database, prod_schema = _get_prod_path_from_config(config, model, dev_database, dev_schema)
174174

175175
if dbt_parser.requires_upper:
176-
dev_qualified_list = [x.upper() for x in [dev_database, dev_schema, model.alias] if x]
177-
prod_qualified_list = [x.upper() for x in [prod_database, prod_schema, model.alias] if x]
176+
dev_qualified_list = [x.upper() for x in [dev_database, dev_schema, dev_alias] if x]
177+
prod_qualified_list = [x.upper() for x in [prod_database, prod_schema, prod_alias] if x]
178178
primary_keys = [x.upper() for x in primary_keys]
179179
else:
180-
dev_qualified_list = [x for x in [dev_database, dev_schema, model.alias] if x]
181-
prod_qualified_list = [x for x in [prod_database, prod_schema, model.alias] if x]
180+
dev_qualified_list = [x for x in [dev_database, dev_schema, dev_alias] if x]
181+
prod_qualified_list = [x for x in [prod_database, prod_schema, prod_alias] if x]
182182

183183
datadiff_model_config = dbt_parser.get_datadiff_model_config(model.meta)
184184

@@ -225,14 +225,16 @@ def _get_prod_path_from_config(config, model, dev_database, dev_schema) -> Tuple
225225
return prod_database, prod_schema
226226

227227

228-
def _get_prod_path_from_manifest(model, prod_manifest) -> Union[Tuple[str, str], Tuple[None, None]]:
228+
def _get_prod_path_from_manifest(model, prod_manifest) -> Union[Tuple[str, str, str], Tuple[None, None, None]]:
229229
prod_database = None
230230
prod_schema = None
231+
prod_alias = None
231232
prod_model = prod_manifest.nodes.get(model.unique_id, None)
232233
if prod_model:
233234
prod_database = prod_model.database
234235
prod_schema = prod_model.schema_
235-
return prod_database, prod_schema
236+
prod_alias = prod_model.alias
237+
return prod_database, prod_schema, prod_alias
236238

237239

238240
def _local_diff(diff_vars: TDiffVars, json_output: bool = False) -> None:

tests/test_dbt.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -684,9 +684,11 @@ def test_get_prod_path_from_manifest_model_exists(self):
684684
mock_prod_manifest.nodes.get.return_value = mock_prod_model
685685
mock_prod_model.database = "prod_db"
686686
mock_prod_model.schema_ = "prod_schema"
687-
prod_database, prod_schema = _get_prod_path_from_manifest(mock_model, mock_prod_manifest)
687+
mock_prod_model.alias = "prod_alias"
688+
prod_database, prod_schema, prod_alias = _get_prod_path_from_manifest(mock_model, mock_prod_manifest)
688689
self.assertEqual(prod_database, mock_prod_model.database)
689690
self.assertEqual(prod_schema, mock_prod_model.schema_)
691+
self.assertEqual(prod_alias, mock_prod_model.alias)
690692

691693
def test_get_prod_path_from_manifest_model_not_exists(self):
692694
mock_model = Mock()
@@ -696,9 +698,11 @@ def test_get_prod_path_from_manifest_model_not_exists(self):
696698
mock_prod_manifest.nodes.get.return_value = None
697699
mock_prod_model.database = "prod_db"
698700
mock_prod_model.schema_ = "prod_schema"
699-
prod_database, prod_schema = _get_prod_path_from_manifest(mock_model, mock_prod_manifest)
701+
mock_prod_model.alias = "prod_alias"
702+
prod_database, prod_schema, prod_alias = _get_prod_path_from_manifest(mock_model, mock_prod_manifest)
700703
self.assertEqual(prod_database, None)
701704
self.assertEqual(prod_schema, None)
705+
self.assertEqual(prod_alias, None)
702706

703707
def test_get_diff_custom_schema_no_config_exception(self):
704708
config = TDatadiffConfig(prod_database="prod_db", prod_schema="prod_schema")
@@ -926,7 +930,7 @@ def test_get_diff_vars_call_get_prod_path_from_manifest(
926930
mock_dbt_parser.requires_upper = False
927931
mock_model.meta = None
928932
mock_dbt_parser.prod_manifest_obj = {"manifest_key": "manifest_value"}
929-
mock_prod_path_from_manifest.return_value = ("prod_db", "prod_schema")
933+
mock_prod_path_from_manifest.return_value = ("prod_db", "prod_schema", "prod_alias")
930934

931935
diff_vars = _get_diff_vars(mock_dbt_parser, config, mock_model)
932936

0 commit comments

Comments
 (0)