Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Issue 425: parse and use threads (#435)

Merged
Merged 1 commit on Mar 7, 2023.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions data_diff/dbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class DiffVars:
primary_keys: List[str]
datasource_id: str
connection: Dict[str, str]
threads: Optional[int]


def dbt_diff(
Expand Down Expand Up @@ -110,16 +111,16 @@ def _get_diff_vars(
dev_qualified_list = [dev_database, dev_schema, model.alias]
prod_qualified_list = [prod_database, prod_schema, model.alias]

return DiffVars(dev_qualified_list, prod_qualified_list, primary_keys, datasource_id, dbt_parser.connection)
return DiffVars(dev_qualified_list, prod_qualified_list, primary_keys, datasource_id, dbt_parser.connection, dbt_parser.threads)


def _local_diff(diff_vars: DiffVars) -> None:
column_diffs_str = ""
dev_qualified_string = ".".join(diff_vars.dev_path)
prod_qualified_string = ".".join(diff_vars.prod_path)

table1 = connect_to_table(diff_vars.connection, dev_qualified_string, tuple(diff_vars.primary_keys))
table2 = connect_to_table(diff_vars.connection, prod_qualified_string, tuple(diff_vars.primary_keys))
table1 = connect_to_table(diff_vars.connection, dev_qualified_string, tuple(diff_vars.primary_keys), diff_vars.threads)
table2 = connect_to_table(diff_vars.connection, prod_qualified_string, tuple(diff_vars.primary_keys), diff_vars.threads)

table1_columns = list(table1.get_schema())
try:
Expand Down Expand Up @@ -260,6 +261,7 @@ def __init__(self, profiles_dir_override: str, project_dir_override: str, is_clo
self.connection = None
self.project_dict = None
self.requires_upper = False
self.threads = None

self.parse_run_results, self.parse_manifest, self.ProfileRenderer, self.yaml = import_dbt()

Expand Down Expand Up @@ -345,6 +347,7 @@ def set_connection(self):
"role": rendered_credentials.get("role"),
"schema": rendered_credentials.get("schema"),
}
self.threads = rendered_credentials.get("threads")
self.requires_upper = True
elif conn_type == "bigquery":
method = rendered_credentials.get("method")
Expand All @@ -357,6 +360,7 @@ def set_connection(self):
"project": rendered_credentials.get("project"),
"dataset": rendered_credentials.get("dataset"),
}
self.threads = rendered_credentials.get("threads")
elif conn_type == "duckdb":
conn_info = {
"driver": conn_type,
Expand All @@ -373,6 +377,7 @@ def set_connection(self):
"port": rendered_credentials.get("port"),
"dbname": rendered_credentials.get("dbname"),
}
self.threads = rendered_credentials.get("threads")
elif conn_type == "databricks":
conn_info = {
"driver": conn_type,
Expand All @@ -382,6 +387,7 @@ def set_connection(self):
"schema": rendered_credentials.get("schema"),
"access_token": rendered_credentials.get("token"),
}
self.threads = rendered_credentials.get("threads")
elif conn_type == "postgres":
conn_info = {
"driver": "postgresql",
Expand All @@ -391,6 +397,7 @@ def set_connection(self):
"port": rendered_credentials.get("port"),
"dbname": rendered_credentials.get("dbname") or rendered_credentials.get("database"),
}
self.threads = rendered_credentials.get("threads")
else:
raise NotImplementedError(f"Provider {conn_type} is not yet supported for dbt diffs")

Expand Down
32 changes: 16 additions & 16 deletions tests/test_dbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ def test_local_diff(self, mock_diff_tables):
dev_qualified_list = ["dev_db", "dev_schema", "dev_table"]
prod_qualified_list = ["prod_db", "prod_schema", "prod_table"]
expected_keys = ["key"]
diff_vars = DiffVars(dev_qualified_list, prod_qualified_list, expected_keys, None, mock_connection)
diff_vars = DiffVars(dev_qualified_list, prod_qualified_list, expected_keys, None, mock_connection, None)
with patch("data_diff.dbt.connect_to_table", side_effect=[mock_table1, mock_table2]) as mock_connect:
_local_diff(diff_vars)

Expand All @@ -368,8 +368,8 @@ def test_local_diff(self, mock_diff_tables):
)
self.assertEqual(len(mock_diff_tables.call_args[1]['extra_columns']), 2)
self.assertEqual(mock_connect.call_count, 2)
mock_connect.assert_any_call(mock_connection, ".".join(dev_qualified_list), tuple(expected_keys))
mock_connect.assert_any_call(mock_connection, ".".join(prod_qualified_list), tuple(expected_keys))
mock_connect.assert_any_call(mock_connection, ".".join(dev_qualified_list), tuple(expected_keys), None)
mock_connect.assert_any_call(mock_connection, ".".join(prod_qualified_list), tuple(expected_keys), None)
mock_diff.get_stats_string.assert_called_once()

@patch("data_diff.dbt.diff_tables")
Expand All @@ -386,7 +386,7 @@ def test_local_diff_no_diffs(self, mock_diff_tables):
dev_qualified_list = ["dev_db", "dev_schema", "dev_table"]
prod_qualified_list = ["prod_db", "prod_schema", "prod_table"]
expected_keys = ["primary_key_column"]
diff_vars = DiffVars(dev_qualified_list, prod_qualified_list, expected_keys, None, mock_connection)
diff_vars = DiffVars(dev_qualified_list, prod_qualified_list, expected_keys, None, mock_connection, None)
with patch("data_diff.dbt.connect_to_table", side_effect=[mock_table1, mock_table2]) as mock_connect:
_local_diff(diff_vars)

Expand All @@ -395,8 +395,8 @@ def test_local_diff_no_diffs(self, mock_diff_tables):
)
self.assertEqual(len(mock_diff_tables.call_args[1]['extra_columns']), 2)
self.assertEqual(mock_connect.call_count, 2)
mock_connect.assert_any_call(mock_connection, ".".join(dev_qualified_list), tuple(expected_keys))
mock_connect.assert_any_call(mock_connection, ".".join(prod_qualified_list), tuple(expected_keys))
mock_connect.assert_any_call(mock_connection, ".".join(dev_qualified_list), tuple(expected_keys), None)
mock_connect.assert_any_call(mock_connection, ".".join(prod_qualified_list), tuple(expected_keys), None)
mock_diff.get_stats_string.assert_not_called()

@patch("data_diff.dbt.rich.print")
Expand All @@ -413,7 +413,7 @@ def test_cloud_diff(self, mock_request, mock_os_environ, mock_print):
expected_datasource_id = 1
expected_primary_keys = ["primary_key_column"]
diff_vars = DiffVars(
dev_qualified_list, prod_qualified_list, expected_primary_keys, expected_datasource_id, None
dev_qualified_list, prod_qualified_list, expected_primary_keys, expected_datasource_id, None, None
)
_cloud_diff(diff_vars)

Expand Down Expand Up @@ -443,7 +443,7 @@ def test_cloud_diff_ds_id_none(self, mock_request, mock_os_environ, mock_print):
prod_qualified_list = ["prod_db", "prod_schema", "prod_table"]
expected_datasource_id = None
primary_keys = ["primary_key_column"]
diff_vars = DiffVars(dev_qualified_list, prod_qualified_list, primary_keys, expected_datasource_id, None)
diff_vars = DiffVars(dev_qualified_list, prod_qualified_list, primary_keys, expected_datasource_id, None, None)
with self.assertRaises(ValueError):
_cloud_diff(diff_vars)

Expand All @@ -463,7 +463,7 @@ def test_cloud_diff_api_key_none(self, mock_request, mock_os_environ, mock_print
prod_qualified_list = ["prod_db", "prod_schema", "prod_table"]
expected_datasource_id = 1
primary_keys = ["primary_key_column"]
diff_vars = DiffVars(dev_qualified_list, prod_qualified_list, primary_keys, expected_datasource_id, None)
diff_vars = DiffVars(dev_qualified_list, prod_qualified_list, primary_keys, expected_datasource_id, None, None)
with self.assertRaises(ValueError):
_cloud_diff(diff_vars)

Expand All @@ -487,7 +487,7 @@ def test_diff_is_cloud(self, mock_print, mock_dbt_parser, mock_cloud_diff, mock_
mock_dbt_parser.return_value = mock_dbt_parser_inst
mock_dbt_parser_inst.get_models.return_value = [mock_model]
mock_dbt_parser_inst.get_datadiff_variables.return_value = expected_dbt_vars_dict
expected_diff_vars = DiffVars(["dev"], ["prod"], ["pks"], 123, None)
expected_diff_vars = DiffVars(["dev"], ["prod"], ["pks"], 123, None, None)
mock_get_diff_vars.return_value = expected_diff_vars
dbt_diff(is_cloud=True)
mock_dbt_parser_inst.get_models.assert_called_once()
Expand All @@ -514,7 +514,7 @@ def test_diff_is_not_cloud(self, mock_print, mock_dbt_parser, mock_cloud_diff, m
}
mock_dbt_parser_inst.get_models.return_value = [mock_model]
mock_dbt_parser_inst.get_datadiff_variables.return_value = expected_dbt_vars_dict
expected_diff_vars = DiffVars(["dev"], ["prod"], ["pks"], 123, None)
expected_diff_vars = DiffVars(["dev"], ["prod"], ["pks"], 123, None, None)
mock_get_diff_vars.return_value = expected_diff_vars
dbt_diff(is_cloud=False)

Expand Down Expand Up @@ -542,7 +542,7 @@ def test_diff_no_prod_configs(

mock_dbt_parser_inst.get_models.return_value = [mock_model]
mock_dbt_parser_inst.get_datadiff_variables.return_value = expected_dbt_vars_dict
expected_diff_vars = DiffVars(["dev"], ["prod"], ["pks"], 123, None)
expected_diff_vars = DiffVars(["dev"], ["prod"], ["pks"], 123, None, None)
mock_get_diff_vars.return_value = expected_diff_vars
with self.assertRaises(ValueError):
dbt_diff(is_cloud=False)
Expand Down Expand Up @@ -570,7 +570,7 @@ def test_diff_only_prod_db(self, mock_print, mock_dbt_parser, mock_cloud_diff, m
}
mock_dbt_parser_inst.get_models.return_value = [mock_model]
mock_dbt_parser_inst.get_datadiff_variables.return_value = expected_dbt_vars_dict
expected_diff_vars = DiffVars(["dev"], ["prod"], ["pks"], 123, None)
expected_diff_vars = DiffVars(["dev"], ["prod"], ["pks"], 123, None, None)
mock_get_diff_vars.return_value = expected_diff_vars
dbt_diff(is_cloud=False)

Expand Down Expand Up @@ -599,7 +599,7 @@ def test_diff_only_prod_schema(

mock_dbt_parser_inst.get_models.return_value = [mock_model]
mock_dbt_parser_inst.get_datadiff_variables.return_value = expected_dbt_vars_dict
expected_diff_vars = DiffVars(["dev"], ["prod"], ["pks"], 123, None)
expected_diff_vars = DiffVars(["dev"], ["prod"], ["pks"], 123, None, None)
mock_get_diff_vars.return_value = expected_diff_vars
with self.assertRaises(ValueError):
dbt_diff(is_cloud=False)
Expand Down Expand Up @@ -631,7 +631,7 @@ def test_diff_is_cloud_no_pks(

mock_dbt_parser_inst.get_models.return_value = [mock_model]
mock_dbt_parser_inst.get_datadiff_variables.return_value = expected_dbt_vars_dict
expected_diff_vars = DiffVars(["dev"], ["prod"], [], 123, None)
expected_diff_vars = DiffVars(["dev"], ["prod"], [], 123, None, None)
mock_get_diff_vars.return_value = expected_diff_vars
dbt_diff(is_cloud=True)

Expand Down Expand Up @@ -662,7 +662,7 @@ def test_diff_not_is_cloud_no_pks(
mock_dbt_parser_inst.get_models.return_value = [mock_model]
mock_dbt_parser_inst.get_datadiff_variables.return_value = expected_dbt_vars_dict

expected_diff_vars = DiffVars(["dev"], ["prod"], [], 123, None)
expected_diff_vars = DiffVars(["dev"], ["prod"], [], 123, None, None)
mock_get_diff_vars.return_value = expected_diff_vars
dbt_diff(is_cloud=False)
mock_dbt_parser_inst.get_models.assert_called_once()
Expand Down