Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit c6ab394

Browse files
committed
override --cloud host name
1 parent f406c0a commit c6ab394

File tree

3 files changed

+84
-7
lines changed

3 files changed

+84
-7
lines changed

data_diff/__main__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,12 @@ def write_usage(self, prog: str, args: str = "", prefix: Optional[str] = None) -
228228
metavar="PATH",
229229
help="Which directory to look in for the dbt_project.yml file. Default is the current working directory and its parents.",
230230
)
231+
@click.option(
232+
"--cloud-host-name",
233+
envvar="DATAFOLD_API_HOST_NAME",
234+
default=None,
235+
help="Override the host name for on-premise Datafold deployments --cloud api calls. Can also be set via the DATAFOLD_API_HOST_NAME environment variable.",
236+
)
231237
def main(conf, run, **kw):
232238
if kw["table2"] is None and kw["database2"]:
233239
# Use the "database table table" form
@@ -264,6 +270,7 @@ def main(conf, run, **kw):
264270
profiles_dir_override=kw["dbt_profiles_dir"],
265271
project_dir_override=kw["dbt_project_dir"],
266272
is_cloud=kw["cloud"],
273+
cloud_host_name=kw["cloud_host_name"],
267274
)
268275
else:
269276
return _data_diff(**kw)
@@ -306,6 +313,7 @@ def _data_diff(
306313
cloud,
307314
dbt_profiles_dir,
308315
dbt_project_dir,
316+
cloud_host_name,
309317
threads1=None,
310318
threads2=None,
311319
__conf__=None,

data_diff/dbt.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,10 @@ class DiffVars:
7373

7474

7575
def dbt_diff(
76-
profiles_dir_override: Optional[str] = None, project_dir_override: Optional[str] = None, is_cloud: bool = False
76+
profiles_dir_override: Optional[str] = None,
77+
project_dir_override: Optional[str] = None,
78+
is_cloud: bool = False,
79+
cloud_host_name: Optional[str] = None,
7780
) -> None:
7881
set_entrypoint_name("CLI-dbt")
7982
dbt_parser = DbtParser(profiles_dir_override, project_dir_override, is_cloud)
@@ -101,7 +104,7 @@ def dbt_diff(
101104
)
102105

103106
if is_cloud and len(diff_vars.primary_keys) > 0:
104-
_cloud_diff(diff_vars)
107+
_cloud_diff(diff_vars, cloud_host_name)
105108
elif not is_cloud and len(diff_vars.primary_keys) > 0:
106109
_local_diff(diff_vars)
107110
else:
@@ -219,8 +222,9 @@ def _local_diff(diff_vars: DiffVars) -> None:
219222
)
220223

221224

222-
def _cloud_diff(diff_vars: DiffVars) -> None:
225+
def _cloud_diff(diff_vars: DiffVars, host_name: Optional[str] = None) -> None:
223226
api_key = os.environ.get("DATAFOLD_API_KEY")
227+
host_name = host_name or "app.datafold.com"
224228

225229
if diff_vars.datasource_id is None:
226230
raise ValueError(
@@ -229,7 +233,7 @@ def _cloud_diff(diff_vars: DiffVars) -> None:
229233
if api_key is None:
230234
raise ValueError("API key not found, add it as an environment variable called DATAFOLD_API_KEY.")
231235

232-
url = "https://app.datafold.com/api/v1/datadiffs"
236+
url = f"https://{host_name}/api/v1/datadiffs"
233237

234238
payload = {
235239
"data_source1_id": diff_vars.datasource_id,
@@ -255,8 +259,8 @@ def _cloud_diff(diff_vars: DiffVars) -> None:
255259
response.raise_for_status()
256260
data = response.json()
257261
diff_id = data["id"]
258-
# TODO in future we should support self hosted datafold
259-
diff_url = f"https://app.datafold.com/datadiffs/{diff_id}/overview"
262+
263+
diff_url = f"https://{host_name}/datadiffs/{diff_id}/overview"
260264
rich.print(
261265
"[red]"
262266
+ ".".join(diff_vars.prod_path)

tests/test_dbt.py

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -426,11 +426,48 @@ def test_cloud_diff(self, mock_request, mock_os_environ, mock_print):
426426

427427
mock_request.assert_called_once()
428428
mock_print.assert_called_once()
429+
request_endpoint = mock_request.call_args[0][1]
429430
request_data_dict = mock_request.call_args[1]["json"]
430431
self.assertEqual(
431432
mock_request.call_args[1]["headers"]["Authorization"],
432433
"Key " + expected_api_key,
433434
)
435+
self.assertEqual(request_endpoint, f'https://app.datafold.com/api/v1/datadiffs')
436+
self.assertEqual(request_data_dict["data_source1_id"], expected_datasource_id)
437+
self.assertEqual(request_data_dict["data_source2_id"], expected_datasource_id)
438+
self.assertEqual(request_data_dict["table1"], prod_qualified_list)
439+
self.assertEqual(request_data_dict["table2"], dev_qualified_list)
440+
self.assertEqual(request_data_dict["pk_columns"], expected_primary_keys)
441+
442+
@patch("data_diff.dbt.rich.print")
443+
@patch("data_diff.dbt.os.environ")
444+
@patch("data_diff.dbt.requests.request")
445+
def test_cloud_diff_host_name_override(self, mock_request, mock_os_environ, mock_print):
446+
expected_api_key = "an_api_key"
447+
mock_response = Mock()
448+
mock_response.json.return_value = {"id": 123}
449+
mock_request.return_value = mock_response
450+
mock_os_environ.get.return_value = expected_api_key
451+
dev_qualified_list = ["dev_db", "dev_schema", "dev_table"]
452+
prod_qualified_list = ["prod_db", "prod_schema", "prod_table"]
453+
expected_datasource_id = 1
454+
expected_primary_keys = ["primary_key_column"]
455+
diff_vars = DiffVars(
456+
dev_qualified_list, prod_qualified_list, expected_primary_keys, expected_datasource_id, None, None
457+
)
458+
host_name = "a_host_name"
459+
_cloud_diff(diff_vars, host_name)
460+
461+
mock_request.assert_called_once()
462+
mock_print.assert_called_once()
463+
464+
request_endpoint = mock_request.call_args[0][1]
465+
request_data_dict = mock_request.call_args[1]["json"]
466+
self.assertEqual(
467+
mock_request.call_args[1]["headers"]["Authorization"],
468+
"Key " + expected_api_key,
469+
)
470+
self.assertEqual(request_endpoint, f'https://{host_name}/api/v1/datadiffs')
434471
self.assertEqual(request_data_dict["data_source1_id"], expected_datasource_id)
435472
self.assertEqual(request_data_dict["data_source2_id"], expected_datasource_id)
436473
self.assertEqual(request_data_dict["table1"], prod_qualified_list)
@@ -500,7 +537,35 @@ def test_diff_is_cloud(self, mock_print, mock_dbt_parser, mock_cloud_diff, mock_
500537
mock_dbt_parser_inst.get_models.assert_called_once()
501538
mock_dbt_parser_inst.set_connection.assert_not_called()
502539

503-
mock_cloud_diff.assert_called_once_with(expected_diff_vars)
540+
mock_cloud_diff.assert_called_once_with(expected_diff_vars, None)
541+
mock_local_diff.assert_not_called()
542+
mock_print.assert_called_once()
543+
544+
@patch("data_diff.dbt._get_diff_vars")
545+
@patch("data_diff.dbt._local_diff")
546+
@patch("data_diff.dbt._cloud_diff")
547+
@patch("data_diff.dbt.DbtParser.__new__")
548+
@patch("data_diff.dbt.rich.print")
549+
def test_diff_is_cloud(self, mock_print, mock_dbt_parser, mock_cloud_diff, mock_local_diff, mock_get_diff_vars):
550+
mock_dbt_parser_inst = Mock()
551+
mock_model = Mock()
552+
expected_dbt_vars_dict = {
553+
"prod_database": "prod_db",
554+
"prod_schema": "prod_schema",
555+
"datasource_id": 1,
556+
}
557+
host_name = 'a_host_name'
558+
559+
mock_dbt_parser.return_value = mock_dbt_parser_inst
560+
mock_dbt_parser_inst.get_models.return_value = [mock_model]
561+
mock_dbt_parser_inst.get_datadiff_variables.return_value = expected_dbt_vars_dict
562+
expected_diff_vars = DiffVars(["dev"], ["prod"], ["pks"], 123, None, None)
563+
mock_get_diff_vars.return_value = expected_diff_vars
564+
dbt_diff(is_cloud=True, cloud_host_name=host_name)
565+
mock_dbt_parser_inst.get_models.assert_called_once()
566+
mock_dbt_parser_inst.set_connection.assert_not_called()
567+
568+
mock_cloud_diff.assert_called_once_with(expected_diff_vars, host_name)
504569
mock_local_diff.assert_not_called()
505570
mock_print.assert_called_once()
506571

0 commit comments

Comments
 (0)