@@ -30,8 +30,9 @@ class TestDbtDiffer(unittest.TestCase):
30
30
def test_integration_basic_dbt (self ):
31
31
artifacts_path = os .getcwd () + "/tests/dbt_artifacts"
32
32
test_project_path = os .environ .get ("DATA_DIFF_DBT_PROJ" ) or artifacts_path
33
+ test_profiles_path = os .environ .get ("DATA_DIFF_DBT_PROFILES" ) or artifacts_path
33
34
diff = run_datadiff_cli (
34
- "--dbt" , "--dbt-project-dir" , test_project_path , "--dbt-profiles-dir" , test_project_path
35
+ "--dbt" , "--dbt-project-dir" , test_project_path , "--dbt-profiles-dir" , test_profiles_path
35
36
)
36
37
37
38
# assertions for the diff that exists in tests/dbt_artifacts/jaffle_shop.duckdb
@@ -933,3 +934,37 @@ def test_get_diff_vars_call_get_prod_path_from_manifest(
933
934
mock_prod_path_from_manifest .assert_called_once_with (mock_model , mock_dbt_parser .prod_manifest_obj )
934
935
self .assertEqual (diff_vars .prod_path [0 ], mock_prod_path_from_manifest .return_value [0 ])
935
936
self .assertEqual (diff_vars .prod_path [1 ], mock_prod_path_from_manifest .return_value [1 ])
937
+
938
+ @patch ("data_diff.dbt._get_prod_path_from_config" )
939
+ @patch ("data_diff.dbt._get_prod_path_from_manifest" )
940
+ def test_get_diff_vars_cli_columns (self , mock_prod_path_from_manifest , mock_prod_path_from_config ):
941
+ config = TDatadiffConfig (prod_database = "prod_db" )
942
+ mock_model = Mock ()
943
+ primary_keys = ["a_primary_key" ]
944
+ mock_model .database = "a_dev_db"
945
+ mock_model .schema_ = "a_schema"
946
+ mock_model .config .schema_ = None
947
+ mock_model .config .database = None
948
+ mock_model .alias = "a_model_name"
949
+ mock_model .unique_id = "unique_id"
950
+ mock_tdatadiffmodelconfig = Mock ()
951
+ mock_tdatadiffmodelconfig .where_filter = "where"
952
+ mock_tdatadiffmodelconfig .include_columns = ["include" ]
953
+ mock_tdatadiffmodelconfig .exclude_columns = ["exclude" ]
954
+ mock_dbt_parser = Mock ()
955
+ mock_dbt_parser .get_datadiff_model_config .return_value = mock_tdatadiffmodelconfig
956
+ mock_dbt_parser .connection = {}
957
+ mock_dbt_parser .threads = 0
958
+ mock_dbt_parser .get_pk_from_model .return_value = primary_keys
959
+ mock_dbt_parser .requires_upper = False
960
+ mock_dbt_parser .prod_manifest_obj = None
961
+ mock_prod_path_from_config .return_value = ("prod_db" , "prod_schema" )
962
+ cli_columns = ("col1" , "col2" )
963
+
964
+ diff_vars = _get_diff_vars (mock_dbt_parser , config , mock_model , where_flag = None , columns_flag = cli_columns )
965
+
966
+ mock_dbt_parser .get_pk_from_model .assert_called_once ()
967
+ mock_prod_path_from_config .assert_called_once_with (config , mock_model , mock_model .database , mock_model .schema_ )
968
+ mock_prod_path_from_manifest .assert_not_called ()
969
+ self .assertEqual (diff_vars .include_columns , list (cli_columns ))
970
+ self .assertEqual (diff_vars .exclude_columns , [])
0 commit comments