Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

render jinja in entire selected profile #395

Merged
merged 3 commits into from
Mar 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 44 additions & 44 deletions data_diff/dbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,100 +304,100 @@ def _get_connection_creds(self) -> Tuple[Dict[str, str], str]:
with open(profiles_path) as profiles:
profiles = self.yaml.safe_load(profiles)

dbt_profile = self.project_dict.get("profile")
dbt_profile_var = self.project_dict.get("profile")

profile_outputs = get_from_dict_with_raise(
profiles, dbt_profile, f"No profile '{dbt_profile}' found in '{profiles_path}'."
profile = get_from_dict_with_raise(
profiles, dbt_profile_var, f"No profile '{dbt_profile_var}' found in '{profiles_path}'."
)
# values can contain env_vars
rendered_profile = self.ProfileRenderer().render_data(profile)
profile_target = get_from_dict_with_raise(
profile_outputs, "target", f"No target found in profile '{dbt_profile}' in '{profiles_path}'."
rendered_profile, "target", f"No target found in profile '{dbt_profile_var}' in '{profiles_path}'."
)
outputs = get_from_dict_with_raise(
profile_outputs, "outputs", f"No outputs found in profile '{dbt_profile}' in '{profiles_path}'."
rendered_profile, "outputs", f"No outputs found in profile '{dbt_profile_var}' in '{profiles_path}'."
)
credentials = get_from_dict_with_raise(
outputs,
profile_target,
f"No credentials found for target '{profile_target}' in profile '{dbt_profile}' in '{profiles_path}'.",
f"No credentials found for target '{profile_target}' in profile '{dbt_profile_var}' in '{profiles_path}'.",
)
conn_type = get_from_dict_with_raise(
credentials,
"type",
f"No type found for target '{profile_target}' in profile '{dbt_profile}' in '{profiles_path}'.",
f"No type found for target '{profile_target}' in profile '{dbt_profile_var}' in '{profiles_path}'.",
)
conn_type = conn_type.lower()

# values can contain env_vars
rendered_credentials = self.ProfileRenderer().render_data(credentials)
return rendered_credentials, conn_type
return credentials, conn_type

def set_connection(self):
rendered_credentials, conn_type = self._get_connection_creds()
credentials, conn_type = self._get_connection_creds()

if conn_type == "snowflake":
if rendered_credentials.get("password") is None or rendered_credentials.get("private_key_path") is not None:
if credentials.get("password") is None or credentials.get("private_key_path") is not None:
raise Exception("Only password authentication is currently supported for Snowflake.")
conn_info = {
"driver": conn_type,
"user": rendered_credentials.get("user"),
"password": rendered_credentials.get("password"),
"account": rendered_credentials.get("account"),
"database": rendered_credentials.get("database"),
"warehouse": rendered_credentials.get("warehouse"),
"role": rendered_credentials.get("role"),
"schema": rendered_credentials.get("schema"),
"user": credentials.get("user"),
"password": credentials.get("password"),
"account": credentials.get("account"),
"database": credentials.get("database"),
"warehouse": credentials.get("warehouse"),
"role": credentials.get("role"),
"schema": credentials.get("schema"),
}
self.threads = rendered_credentials.get("threads")
self.threads = credentials.get("threads")
self.requires_upper = True
elif conn_type == "bigquery":
method = rendered_credentials.get("method")
method = credentials.get("method")
# there are many connection types https://docs.getdbt.com/reference/warehouse-setups/bigquery-setup#oauth-via-gcloud
# this assumes that the user is auth'd via `gcloud auth application-default login`
if method is None or method != "oauth":
raise Exception("Oauth is the current method supported for Big Query.")
conn_info = {
"driver": conn_type,
"project": rendered_credentials.get("project"),
"dataset": rendered_credentials.get("dataset"),
"project": credentials.get("project"),
"dataset": credentials.get("dataset"),
}
self.threads = rendered_credentials.get("threads")
self.threads = credentials.get("threads")
elif conn_type == "duckdb":
conn_info = {
"driver": conn_type,
"filepath": rendered_credentials.get("path"),
"filepath": credentials.get("path"),
}
elif conn_type == "redshift":
if rendered_credentials.get("password") is None or rendered_credentials.get("method") == "iam":
if credentials.get("password") is None or credentials.get("method") == "iam":
raise Exception("Only password authentication is currently supported for Redshift.")
conn_info = {
"driver": conn_type,
"host": rendered_credentials.get("host"),
"user": rendered_credentials.get("user"),
"password": rendered_credentials.get("password"),
"port": rendered_credentials.get("port"),
"dbname": rendered_credentials.get("dbname"),
"host": credentials.get("host"),
"user": credentials.get("user"),
"password": credentials.get("password"),
"port": credentials.get("port"),
"dbname": credentials.get("dbname"),
}
self.threads = rendered_credentials.get("threads")
self.threads = credentials.get("threads")
elif conn_type == "databricks":
conn_info = {
"driver": conn_type,
"catalog": rendered_credentials.get("catalog"),
"server_hostname": rendered_credentials.get("host"),
"http_path": rendered_credentials.get("http_path"),
"schema": rendered_credentials.get("schema"),
"access_token": rendered_credentials.get("token"),
"catalog": credentials.get("catalog"),
"server_hostname": credentials.get("host"),
"http_path": credentials.get("http_path"),
"schema": credentials.get("schema"),
"access_token": credentials.get("token"),
}
self.threads = rendered_credentials.get("threads")
self.threads = credentials.get("threads")
elif conn_type == "postgres":
conn_info = {
"driver": "postgresql",
"host": rendered_credentials.get("host"),
"user": rendered_credentials.get("user"),
"password": rendered_credentials.get("password"),
"port": rendered_credentials.get("port"),
"dbname": rendered_credentials.get("dbname") or rendered_credentials.get("database"),
"host": credentials.get("host"),
"user": credentials.get("user"),
"password": credentials.get("password"),
"port": credentials.get("port"),
"dbname": credentials.get("dbname") or credentials.get("database"),
}
self.threads = rendered_credentials.get("threads")
self.threads = credentials.get("threads")
else:
raise NotImplementedError(f"Provider {conn_type} is not yet supported for dbt diffs")

Expand Down
56 changes: 31 additions & 25 deletions tests/test_dbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,37 +211,40 @@ def test_set_connection_not_implemented(self):

@patch("builtins.open", new_callable=mock_open, read_data="")
def test_get_connection_creds_success(self, mock_open):
profile_dict = {
profiles_dict = {
"a_profile": {
"outputs": {
"a_target": {"type": "TYPE1", "credential_1": "credential_1", "credential_2": "credential_2"}
},
"target": "a_target",
}
}
expected_credentials = profile_dict["a_profile"]["outputs"]["a_target"]
profile = profiles_dict["a_profile"]
expected_credentials = profiles_dict["a_profile"]["outputs"]["a_target"]
mock_self = Mock()
mock_self.profiles_dir = ""
mock_self.project_dict = {"profile": "a_profile"}
mock_self.yaml.safe_load.return_value = profile_dict
mock_self.ProfileRenderer().render_data.return_value = expected_credentials
mock_self.yaml.safe_load.return_value = profiles_dict
mock_self.ProfileRenderer().render_data.return_value = profile
credentials, conn_type = DbtParser._get_connection_creds(mock_self)
self.assertEqual(credentials, expected_credentials)
self.assertEqual(conn_type, "type1")

@patch("builtins.open", new_callable=mock_open, read_data="")
def test_get_connection_no_matching_profile(self, mock_open):
profile_dict = {"a_profile": {}}
profiles_dict = {"a_profile": {}}
mock_self = Mock()
mock_self.profiles_dir = ""
mock_self.project_dict = {"profile": "wrong_profile"}
mock_self.yaml.safe_load.return_value = profile_dict
mock_self.yaml.safe_load.return_value = profiles_dict
profile = profiles_dict["a_profile"]
mock_self.ProfileRenderer().render_data.return_value = profile
with self.assertRaises(ValueError):
_, _ = DbtParser._get_connection_creds(mock_self)

@patch("builtins.open", new_callable=mock_open, read_data="")
def test_get_connection_no_target(self, mock_open):
profile_dict = {
profiles_dict = {
"a_profile": {
"outputs": {
"a_target": {"type": "TYPE1", "credential_1": "credential_1", "credential_2": "credential_2"}
Expand All @@ -250,8 +253,10 @@ def test_get_connection_no_target(self, mock_open):
}
mock_self = Mock()
mock_self.profiles_dir = ""
profile = profiles_dict["a_profile"]
mock_self.ProfileRenderer().render_data.return_value = profile
mock_self.project_dict = {"profile": "a_profile"}
mock_self.yaml.safe_load.return_value = profile_dict
mock_self.yaml.safe_load.return_value = profiles_dict
with self.assertRaises(ValueError):
_, _ = DbtParser._get_connection_creds(mock_self)

Expand All @@ -262,24 +267,19 @@ def test_get_connection_no_target(self, mock_open):

@patch("builtins.open", new_callable=mock_open, read_data="")
def test_get_connection_no_outputs(self, mock_open):
profile_dict = {"a_profile": {"target": "a_target"}}
profiles_dict = {"a_profile": {"target": "a_target"}}
mock_self = Mock()
mock_self.profiles_dir = ""
mock_self.project_dict = {"profile": "a_profile"}
mock_self.yaml.safe_load.return_value = profile_dict
profile = profiles_dict["a_profile"]
mock_self.ProfileRenderer().render_data.return_value = profile
mock_self.yaml.safe_load.return_value = profiles_dict
with self.assertRaises(ValueError):
_, _ = DbtParser._get_connection_creds(mock_self)

profile_yaml_no_credentials = """
a_profile:
outputs:
a_target:
target: a_target
"""

@patch("builtins.open", new_callable=mock_open, read_data="")
def test_get_connection_no_credentials(self, mock_open):
profile_dict = {
profiles_dict = {
"a_profile": {
"outputs": {"a_target": {}},
"target": "a_target",
Expand All @@ -288,13 +288,15 @@ def test_get_connection_no_credentials(self, mock_open):
mock_self = Mock()
mock_self.profiles_dir = ""
mock_self.project_dict = {"profile": "a_profile"}
mock_self.yaml.safe_load.return_value = profile_dict
mock_self.yaml.safe_load.return_value = profiles_dict
profile = profiles_dict["a_profile"]
mock_self.ProfileRenderer().render_data.return_value = profile
with self.assertRaises(ValueError):
_, _ = DbtParser._get_connection_creds(mock_self)

@patch("builtins.open", new_callable=mock_open, read_data="")
def test_get_connection_no_target_credentials(self, mock_open):
profile_dict = {
profiles_dict = {
"a_profile": {
"outputs": {
"a_target": {"type": "TYPE1", "credential_1": "credential_1", "credential_2": "credential_2"}
Expand All @@ -305,13 +307,15 @@ def test_get_connection_no_target_credentials(self, mock_open):
mock_self = Mock()
mock_self.profiles_dir = ""
mock_self.project_dict = {"profile": "a_profile"}
mock_self.yaml.safe_load.return_value = profile_dict
profile = profiles_dict["a_profile"]
mock_self.ProfileRenderer().render_data.return_value = profile
mock_self.yaml.safe_load.return_value = profiles_dict
with self.assertRaises(ValueError):
_, _ = DbtParser._get_connection_creds(mock_self)

@patch("builtins.open", new_callable=mock_open, read_data="")
def test_get_connection_no_type(self, mock_open):
profile_dict = {
profiles_dict = {
"a_profile": {
"outputs": {"a_target": {"credential_1": "credential_1", "credential_2": "credential_2"}},
"target": "a_target",
Expand All @@ -320,7 +324,9 @@ def test_get_connection_no_type(self, mock_open):
mock_self = Mock()
mock_self.profiles_dir = ""
mock_self.project_dict = {"profile": "a_profile"}
mock_self.yaml.safe_load.return_value = profile_dict
mock_self.yaml.safe_load.return_value = profiles_dict
profile = profiles_dict["a_profile"]
mock_self.ProfileRenderer().render_data.return_value = profile
with self.assertRaises(ValueError):
_, _ = DbtParser._get_connection_creds(mock_self)

Expand Down Expand Up @@ -366,7 +372,7 @@ def test_local_diff(self, mock_diff_tables):
mock_diff_tables.assert_called_once_with(
mock_table1, mock_table2, threaded=True, algorithm=Algorithm.JOINDIFF, extra_columns=ANY
)
self.assertEqual(len(mock_diff_tables.call_args[1]['extra_columns']), 2)
self.assertEqual(len(mock_diff_tables.call_args[1]["extra_columns"]), 2)
self.assertEqual(mock_connect.call_count, 2)
mock_connect.assert_any_call(mock_connection, ".".join(dev_qualified_list), tuple(expected_keys), None)
mock_connect.assert_any_call(mock_connection, ".".join(prod_qualified_list), tuple(expected_keys), None)
Expand All @@ -393,7 +399,7 @@ def test_local_diff_no_diffs(self, mock_diff_tables):
mock_diff_tables.assert_called_once_with(
mock_table1, mock_table2, threaded=True, algorithm=Algorithm.JOINDIFF, extra_columns=ANY
)
self.assertEqual(len(mock_diff_tables.call_args[1]['extra_columns']), 2)
self.assertEqual(len(mock_diff_tables.call_args[1]["extra_columns"]), 2)
self.assertEqual(mock_connect.call_count, 2)
mock_connect.assert_any_call(mock_connection, ".".join(dev_qualified_list), tuple(expected_keys), None)
mock_connect.assert_any_call(mock_connection, ".".join(prod_qualified_list), tuple(expected_keys), None)
Expand Down