Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

cloud api token flow #462

Merged
merged 4 commits into from
Mar 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 18 additions & 5 deletions data_diff/dbt.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import json
import os
import time
import webbrowser
import rich
from rich.prompt import Confirm

from collections import defaultdict
from dataclasses import dataclass
Expand Down Expand Up @@ -220,16 +222,28 @@ def _local_diff(diff_vars: DiffVars) -> None:


def _cloud_diff(diff_vars: DiffVars) -> None:
datafold_host = os.environ.get("DATAFOLD_HOST")
if datafold_host is None:
datafold_host = "https://app.datafold.com"
datafold_host = datafold_host.rstrip("/")
rich.print(f"Cloud datafold host: {datafold_host}")

api_key = os.environ.get("DATAFOLD_API_KEY")
if not api_key:
rich.print("[red]API key not found, add it as an environment variable called DATAFOLD_API_KEY.")
yes_or_no = Confirm.ask("Would you like to generate a new API key?")
if yes_or_no:
webbrowser.open(f"{datafold_host}/login?next={datafold_host}/users/me")
return
else:
raise ValueError("Cannot diff because the API key is not provided")

if diff_vars.datasource_id is None:
raise ValueError(
"Datasource ID not found, include it as a dbt variable in the dbt_project.yml. \nvars:\n data_diff:\n datasource_id: 1234"
)
if api_key is None:
raise ValueError("API key not found, add it as an environment variable called DATAFOLD_API_KEY.")

url = "https://app.datafold.com/api/v1/datadiffs"
url = f"{datafold_host}/api/v1/datadiffs"

payload = {
"data_source1_id": diff_vars.datasource_id,
Expand All @@ -255,8 +269,7 @@ def _cloud_diff(diff_vars: DiffVars) -> None:
response.raise_for_status()
data = response.json()
diff_id = data["id"]
# TODO in future we should support self hosted datafold
diff_url = f"https://app.datafold.com/datadiffs/{diff_id}/overview"
diff_url = f"{datafold_host}/datadiffs/{diff_id}/overview"
rich.print(
"[red]"
+ ".".join(diff_vars.prod_path)
Expand Down
1 change: 0 additions & 1 deletion data_diff/diff_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,6 @@ def get_stats_string(self, is_dbt: bool = False):
string_output += f"\n{k}: {v}"

else:

string_output = ""
string_output += f"{diff_stats.table1_count} rows in table A\n"
string_output += f"{diff_stats.table2_count} rows in table B\n"
Expand Down
1 change: 0 additions & 1 deletion data_diff/joindiff_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,6 @@ def _diff_segments(
if self.materialize_to_table
else None,
):

assert len(a_cols) == len(b_cols)
logger.debug("Querying for different rows")
diff = db.query(diff_rows, list)
Expand Down
3 changes: 3 additions & 0 deletions data_diff/tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,15 @@ def set_entrypoint_name(s):
global entrypoint_name
entrypoint_name = s


dbt_user_id = None


def set_dbt_user_id(s):
    """Record ``s`` as the module-level ``dbt_user_id``.

    The value is stored as-is, with no validation or conversion, and
    replaces whatever was stored before. Returns ``None``.
    """
    # Rebind the module global rather than creating a function-local name.
    global dbt_user_id
    dbt_user_id = s


def get_anonymous_id():
global g_anonymous_id
if g_anonymous_id is None:
Expand Down
14 changes: 8 additions & 6 deletions tests/test_dbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ def test_integration_basic_dbt(self):
artifacts_path = os.getcwd() + '/tests/dbt_artifacts'
test_project_path = os.environ.get("DATA_DIFF_DBT_PROJ") or artifacts_path
diff = run_datadiff_cli("--dbt", "--dbt-project-dir", test_project_path, "--dbt-profiles-dir", test_project_path)
assert diff[-1].decode("utf-8") == "Diffs Complete!"
assert "Diffs Complete!" in '\n'.join(d.decode("utf-8") for d in diff)

# assertions for the diff that exists in tests/dbt_artifacts/jaffle_shop.duckdb
if test_project_path == artifacts_path:
Expand All @@ -340,7 +340,7 @@ def test_integration_basic_dbt(self):
assert diff_string.count('<>') == 5
# 4 with no diffs
assert diff_string.count('No row differences') == 4
# 1 with a diff
# 1 with a diff
assert diff_string.count('| Rows Added | Rows Removed') == 1


Expand Down Expand Up @@ -425,7 +425,7 @@ def test_cloud_diff(self, mock_request, mock_os_environ, mock_print):
_cloud_diff(diff_vars)

mock_request.assert_called_once()
mock_print.assert_called_once()
self.assertEqual(len(mock_print.call_args_list), 2)
request_data_dict = mock_request.call_args[1]["json"]
self.assertEqual(
mock_request.call_args[1]["headers"]["Authorization"],
Expand Down Expand Up @@ -455,17 +455,19 @@ def test_cloud_diff_ds_id_none(self, mock_request, mock_os_environ, mock_print):
_cloud_diff(diff_vars)

mock_request.assert_not_called()
mock_print.assert_not_called()
mock_print.assert_called_once()

@patch("data_diff.dbt.rich.print")
@patch("data_diff.dbt.os.environ")
@patch("data_diff.dbt.requests.request")
def test_cloud_diff_api_key_none(self, mock_request, mock_os_environ, mock_print):
@patch("data_diff.dbt.Confirm.ask")
def test_cloud_diff_api_key_none(self, mock_confirm_answer, mock_request, mock_os_environ, mock_print):
expected_api_key = None
mock_response = Mock()
mock_response.json.return_value = {"id": 123}
mock_request.return_value = mock_response
mock_os_environ.get.return_value = expected_api_key
mock_confirm_answer.return_value = False
dev_qualified_list = ["dev_db", "dev_schema", "dev_table"]
prod_qualified_list = ["prod_db", "prod_schema", "prod_table"]
expected_datasource_id = 1
Expand All @@ -475,7 +477,7 @@ def test_cloud_diff_api_key_none(self, mock_request, mock_os_environ, mock_print
_cloud_diff(diff_vars)

mock_request.assert_not_called()
mock_print.assert_not_called()
self.assertEqual(len(mock_print.call_args_list), 2)

@patch("data_diff.dbt._get_diff_vars")
@patch("data_diff.dbt._local_diff")
Expand Down