Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Issue 417 error handling #419

Merged
merged 4 commits into from
Mar 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 30 additions & 8 deletions data_diff/dbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import rich
from dataclasses import dataclass
from packaging.version import parse as parse_version
from typing import List, Optional, Dict
from typing import List, Optional, Dict, Tuple

import requests

Expand All @@ -28,7 +28,7 @@ def import_dbt():
send_event_json,
is_tracking_enabled,
)
from .utils import run_as_daemon, truncate_error
from .utils import get_from_dict_with_raise, run_as_daemon, truncate_error
from . import connect_to_table, diff_tables, Algorithm

RUN_RESULTS_PATH = "/target/run_results.json"
Expand Down Expand Up @@ -308,18 +308,40 @@ def set_project_dict(self):
with open(self.project_dir + PROJECT_FILE) as project:
self.project_dict = self.yaml.safe_load(project)

def set_connection(self):
with open(self.profiles_dir + PROFILES_FILE) as profiles:
def _get_connection_creds(self) -> Tuple[Dict[str, str], str]:
profiles_path = self.profiles_dir + PROFILES_FILE
with open(profiles_path) as profiles:
profiles = self.yaml.safe_load(profiles)

dbt_profile = self.project_dict.get("profile")
profile_outputs = profiles.get(dbt_profile)
profile_target = profile_outputs.get("target")
credentials = profile_outputs.get("outputs").get(profile_target)
conn_type = credentials.get("type").lower()

profile_outputs = get_from_dict_with_raise(
profiles, dbt_profile, f"No profile '{dbt_profile}' found in '{profiles_path}'."
)
profile_target = get_from_dict_with_raise(
profile_outputs, "target", f"No target found in profile '{dbt_profile}' in '{profiles_path}'."
)
outputs = get_from_dict_with_raise(
profile_outputs, "outputs", f"No outputs found in profile '{dbt_profile}' in '{profiles_path}'."
)
credentials = get_from_dict_with_raise(
outputs,
profile_target,
f"No credentials found for target '{profile_target}' in profile '{dbt_profile}' in '{profiles_path}'.",
)
conn_type = get_from_dict_with_raise(
credentials,
"type",
f"No type found for target '{profile_target}' in profile '{dbt_profile}' in '{profiles_path}'.",
)
conn_type = conn_type.lower()

# values can contain env_vars
rendered_credentials = self.ProfileRenderer().render_data(credentials)
return rendered_credentials, conn_type

def set_connection(self):
rendered_credentials, conn_type = self._get_connection_creds()

if conn_type == "snowflake":
if rendered_credentials.get("password") is None or rendered_credentials.get("private_key_path") is not None:
Expand Down
9 changes: 8 additions & 1 deletion data_diff/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
import re
from typing import Iterable, Sequence
from typing import Dict, Iterable, Sequence
from urllib.parse import urlparse
import operator
import threading
Expand Down Expand Up @@ -79,6 +79,13 @@ def truncate_error(error: str):
return re.sub("'(.*?)'", "'***'", first_line)


def get_from_dict_with_raise(dictionary: Dict, key: str, error_message: str):
result = dictionary.get(key)
if result is None:
raise ValueError(error_message)
return result


class Vector(tuple):

"""Immutable implementation of a regular vector over any arithmetic value
Expand Down
Loading