Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

expand --cloud output by polling for results #467

Merged
merged 4 commits into from
Mar 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
206 changes: 147 additions & 59 deletions data_diff/dbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from dataclasses import dataclass
from packaging.version import parse as parse_version
from typing import List, Optional, Dict, Tuple, Set
from .utils import getLogger
from .utils import dbt_diff_string_template, getLogger
from .version import __version__
from pathlib import Path

Expand Down Expand Up @@ -69,16 +69,16 @@ class DiffVars:
dev_path: List[str]
prod_path: List[str]
primary_keys: List[str]
datasource_id: str
connection: Dict[str, str]
threads: Optional[int]


def dbt_diff(
profiles_dir_override: Optional[str] = None, project_dir_override: Optional[str] = None, is_cloud: bool = False
) -> None:
diff_threads = []
set_entrypoint_name("CLI-dbt")
dbt_parser = DbtParser(profiles_dir_override, project_dir_override, is_cloud)
dbt_parser = DbtParser(profiles_dir_override, project_dir_override)
models = dbt_parser.get_models()
datadiff_variables = dbt_parser.get_datadiff_variables()
config_prod_database = datadiff_variables.get("prod_database")
Expand All @@ -89,7 +89,17 @@ def dbt_diff(
custom_schemas = True if custom_schemas is None else custom_schemas
set_dbt_user_id(dbt_parser.dbt_user_id)

if not is_cloud:
if is_cloud:
if datasource_id is None:
raise ValueError(
"Datasource ID not found, include it as a dbt variable in the dbt_project.yml. \nvars:\n data_diff:\n datasource_id: 1234"
)
datafold_host, url, api_key = _setup_cloud_diff()

# exit so the user can set the key
if not api_key:
return
else:
dbt_parser.set_connection()

if config_prod_database is None:
Expand All @@ -98,14 +108,14 @@ def dbt_diff(
)

for model in models:
diff_vars = _get_diff_vars(
dbt_parser, config_prod_database, config_prod_schema, model, datasource_id, custom_schemas
)

if is_cloud and len(diff_vars.primary_keys) > 0:
_cloud_diff(diff_vars)
elif not is_cloud and len(diff_vars.primary_keys) > 0:
_local_diff(diff_vars)
diff_vars = _get_diff_vars(dbt_parser, config_prod_database, config_prod_schema, model, custom_schemas)

if diff_vars.primary_keys:
if is_cloud:
diff_thread = run_as_daemon(_cloud_diff, diff_vars, datasource_id, datafold_host, url, api_key)
diff_threads.append(diff_thread)
else:
_local_diff(diff_vars)
else:
rich.print(
"[red]"
Expand All @@ -116,6 +126,11 @@ def dbt_diff(
+ "Skipped due to unknown primary key. Add uniqueness tests, meta, or tags.\n"
)

# wait for all threads
if diff_threads:
for thread in diff_threads:
thread.join()

rich.print("Diffs Complete!")


Expand All @@ -124,7 +139,6 @@ def _get_diff_vars(
config_prod_database: Optional[str],
config_prod_schema: Optional[str],
model,
datasource_id: int,
custom_schemas: bool,
) -> DiffVars:
dev_database = model.database
Expand All @@ -149,9 +163,7 @@ def _get_diff_vars(
dev_qualified_list = [dev_database, dev_schema, model.alias]
prod_qualified_list = [prod_database, prod_schema, model.alias]

return DiffVars(
dev_qualified_list, prod_qualified_list, primary_keys, datasource_id, dbt_parser.connection, dbt_parser.threads
)
return DiffVars(dev_qualified_list, prod_qualified_list, primary_keys, dbt_parser.connection, dbt_parser.threads)


def _local_diff(diff_vars: DiffVars) -> None:
Expand Down Expand Up @@ -221,33 +233,10 @@ def _local_diff(diff_vars: DiffVars) -> None:
)


def _cloud_diff(diff_vars: DiffVars) -> None:
datafold_host = os.environ.get("DATAFOLD_HOST")
if datafold_host is None:
datafold_host = "https://app.datafold.com"
datafold_host = datafold_host.rstrip("/")
rich.print(f"Cloud datafold host: {datafold_host}")

api_key = os.environ.get("DATAFOLD_API_KEY")
if not api_key:
rich.print("[red]API key not found, add it as an environment variable called DATAFOLD_API_KEY.")
yes_or_no = Confirm.ask("Would you like to generate a new API key?")
if yes_or_no:
webbrowser.open(f"{datafold_host}/login?next={datafold_host}/users/me")
return
else:
raise ValueError("Cannot diff because the API key is not provided")

if diff_vars.datasource_id is None:
raise ValueError(
"Datasource ID not found, include it as a dbt variable in the dbt_project.yml. \nvars:\n data_diff:\n datasource_id: 1234"
)

url = f"{datafold_host}/api/v1/datadiffs"

def _cloud_diff(diff_vars: DiffVars, datasource_id: int, datafold_host: str, url: str, api_key: str) -> None:
payload = {
"data_source1_id": diff_vars.datasource_id,
"data_source2_id": diff_vars.datasource_id,
"data_source1_id": datasource_id,
"data_source2_id": datasource_id,
"table1": diff_vars.prod_path,
"table2": diff_vars.dev_path,
"pk_columns": diff_vars.primary_keys,
Expand All @@ -258,27 +247,60 @@ def _cloud_diff(diff_vars: DiffVars) -> None:
"Content-Type": "application/json",
}
if is_tracking_enabled():
event_json = create_start_event_json({"is_cloud": True, "datasource_id": diff_vars.datasource_id})
event_json = create_start_event_json({"is_cloud": True, "datasource_id": datasource_id})
run_as_daemon(send_event_json, event_json)

start = time.monotonic()
error = None
diff_id = None
diff_url = None
try:
response = requests.request("POST", url, headers=headers, json=payload, timeout=30)
response.raise_for_status()
data = response.json()
diff_id = data["id"]
diff_id = _cloud_submit_diff(url, payload, headers)
summary_url = f"{url}/{diff_id}/summary_results"
diff_results = _cloud_poll_and_get_summary_results(summary_url, headers)

diff_url = f"{datafold_host}/datadiffs/{diff_id}/overview"
rich.print(
"[red]"
+ ".".join(diff_vars.prod_path)
+ " <> "
+ ".".join(diff_vars.dev_path)
+ "[/] \n Diff in progress: \n "
+ diff_url
+ "\n"
)

rows_added_count = diff_results["pks"]["exclusives"][1]
rows_removed_count = diff_results["pks"]["exclusives"][0]

rows_updated = diff_results["values"]["rows_with_differences"]
total_rows = diff_results["values"]["total_rows"]
rows_unchanged = int(total_rows) - int(rows_updated)
diff_percent_list = {
x["column_name"]: str(x["match"]) + "%"
for x in diff_results["values"]["columns_diff_stats"]
if x["match"] != 100.0
}

if any([rows_added_count, rows_removed_count, rows_updated]):
diff_output = dbt_diff_string_template(
rows_added_count,
rows_removed_count,
rows_updated,
str(rows_unchanged),
diff_percent_list,
"Value Match Percent:",
)
rich.print(
"[red]"
+ ".".join(diff_vars.prod_path)
+ " <> "
+ ".".join(diff_vars.dev_path)
+ f"[/]\n{diff_url}\n"
+ diff_output
+ "\n"
)
else:
rich.print(
"[red]"
+ ".".join(diff_vars.prod_path)
+ " <> "
+ ".".join(diff_vars.dev_path)
+ f"[/]\n{diff_url}\n"
+ "[green]No row differences[/] \n"
)

except BaseException as ex: # Catch KeyboardInterrupt too
error = ex
finally:
Expand All @@ -302,15 +324,81 @@ def _cloud_diff(diff_vars: DiffVars) -> None:
send_event_json(event_json)

if error:
raise error
rich.print(
"[red]"
+ ".".join(diff_vars.prod_path)
+ " <> "
+ ".".join(diff_vars.dev_path) + "[/]\n"
)
if diff_id:
diff_url = f"{datafold_host}/datadiffs/{diff_id}/overview"
rich.print(f"{diff_url} \n")
logger.error(error)


def _setup_cloud_diff() -> Tuple[Optional[str]]:
datafold_host = os.environ.get("DATAFOLD_HOST")
if datafold_host is None:
datafold_host = "https://app.datafold.com"
datafold_host = datafold_host.rstrip("/")
rich.print(f"Cloud datafold host: {datafold_host}\n")
url = f"{datafold_host}/api/v1/datadiffs"

api_key = os.environ.get("DATAFOLD_API_KEY")
if not api_key:
rich.print("[red]API key not found, add it as an environment variable called DATAFOLD_API_KEY.")
yes_or_no = Confirm.ask("Would you like to generate a new API key?")
if yes_or_no:
webbrowser.open(f"{datafold_host}/login?next={datafold_host}/users/me")
return None, None, None
else:
raise ValueError("Cannot diff because the API key is not provided")

return datafold_host, url, api_key


def _cloud_submit_diff(url, payload, headers) -> str:
response = requests.request("POST", url, headers=headers, json=payload, timeout=30)
response.raise_for_status()
response_json = response.json()
diff_id = str(response_json["id"])

if diff_id is None:
raise Exception(f"Api response did not contain a diff_id: {str(response_json)}")
return diff_id


def _cloud_poll_and_get_summary_results(url, headers):
summary_results = None
start_time = time.time()
sleep_interval = 5 # starts at 5 sec
max_sleep_interval = 60
max_wait_time = 300

while not summary_results:
response = requests.request("GET", url, headers=headers, timeout=30)
response.raise_for_status()
response_json = response.json()

if response_json["status"] == "success":
summary_results = response_json
elif response_json["status"] == "failed":
raise Exception(f"Diff failed: {str(response_json)}")

if time.time() - start_time > max_wait_time:
raise Exception("Timed out waiting for diff results")

time.sleep(sleep_interval)
sleep_interval = min(sleep_interval * 2, max_sleep_interval)

return summary_results


class DbtParser:
def __init__(self, profiles_dir_override: str, project_dir_override: str, is_cloud: bool) -> None:
def __init__(self, profiles_dir_override: str, project_dir_override: str) -> None:
self.parse_run_results, self.parse_manifest, self.ProfileRenderer, self.yaml = import_dbt()
self.profiles_dir = Path(profiles_dir_override or default_profiles_dir())
self.project_dir = Path(project_dir_override or default_project_dir())
self.is_cloud = is_cloud
self.connection = None
self.project_dict = self.get_project_dict()
self.manifest_obj = self.get_manifest_obj()
Expand Down
22 changes: 9 additions & 13 deletions data_diff/diff_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from data_diff.info_tree import InfoTree, SegmentInfo

from .utils import run_as_daemon, safezip, getLogger, truncate_error, Vector
from .utils import dbt_diff_string_template, run_as_daemon, safezip, getLogger, truncate_error, Vector
from .thread_utils import ThreadedYielder
from .table_segment import TableSegment, create_mesh_from_points
from .tracking import create_end_event_json, create_start_event_json, send_event_json, is_tracking_enabled
Expand Down Expand Up @@ -139,18 +139,14 @@ def get_stats_string(self, is_dbt: bool = False):
diff_stats = self._get_stats(is_dbt)

if is_dbt:
string_output = "\n| Rows Added\t| Rows Removed\n"
string_output += "------------------------------------------------------------\n"

string_output += f"| {diff_stats.diff_by_sign['-']}\t\t| {diff_stats.diff_by_sign['+']}\n"
string_output += "------------------------------------------------------------\n\n"
string_output += f"Updated Rows: {diff_stats.diff_by_sign['!']}\n"
string_output += f"Unchanged Rows: {diff_stats.unchanged}\n\n"

string_output += f"Values Updated:"

for k, v in diff_stats.extra_column_diffs.items():
string_output += f"\n{k}: {v}"
string_output = dbt_diff_string_template(
diff_stats.diff_by_sign["-"],
diff_stats.diff_by_sign["+"],
diff_stats.diff_by_sign["!"],
diff_stats.unchanged,
diff_stats.extra_column_diffs,
"Values Updated:",
)

else:
string_output = ""
Expand Down
19 changes: 19 additions & 0 deletions data_diff/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,22 @@ def __sub__(self, other: "Vector"):

def __repr__(self) -> str:
return "(%s)" % ", ".join(str(k) for k in self)


def dbt_diff_string_template(
rows_added: str, rows_removed: str, rows_updated: str, rows_unchanged: str, extra_info_dict: Dict, extra_info_str
) -> str:
string_output = "\n| Rows Added\t| Rows Removed\n"
string_output += "------------------------------------------------------------\n"

string_output += f"| {rows_added}\t\t| {rows_removed}\n"
string_output += "------------------------------------------------------------\n\n"
string_output += f"Updated Rows: {rows_updated}\n"
string_output += f"Unchanged Rows: {rows_unchanged}\n\n"

string_output += extra_info_str

for k, v in extra_info_dict.items():
string_output += f"\n{k}: {v}"

return string_output
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,6 @@ build-backend = "poetry.core.masonry.api"

[tool.poetry.scripts]
data-diff = 'data_diff.__main__:main'

[tool.black]
line-length = 120
Loading