This repository was archived by the owner on May 17, 2024. It is now read-only.

Make temp schema optional #509

Merged
Merged 6 commits on Apr 20, 2023
54 changes: 53 additions & 1 deletion data_diff/cloud/data_source.py
@@ -1,3 +1,4 @@
import json
import time
from typing import List, Optional, Union, overload

@@ -50,14 +51,33 @@ def _validate_temp_schema(temp_schema: str):
raise ValueError("Temporary schema should have a format <database>.<schema>")


def _get_temp_schema(dbt_parser: DbtParser, db_type: str) -> Optional[str]:
diff_vars = dbt_parser.get_datadiff_variables()
config_prod_database = diff_vars.get("prod_database")
config_prod_schema = diff_vars.get("prod_schema")
if config_prod_database is not None and config_prod_schema is not None:
temp_schema = f"{config_prod_database}.{config_prod_schema}"
if db_type == "snowflake":
return temp_schema.upper()
elif db_type in {"pg", "postgres_aurora", "postgres_aws_rds", "redshift"}:
return temp_schema.lower()
return temp_schema
return


def create_ds_config(
ds_config: TCloudApiDataSourceConfigSchema,
data_source_name: str,
dbt_parser: Optional[DbtParser] = None,
) -> TDsConfig:
options = _parse_ds_credentials(ds_config=ds_config, only_basic_settings=True, dbt_parser=dbt_parser)

temp_schema = TemporarySchemaPrompt.ask("Temporary schema (<database>.<schema>)")
temp_schema = _get_temp_schema(dbt_parser=dbt_parser, db_type=ds_config.db_type) if dbt_parser else None
if temp_schema:
temp_schema = TemporarySchemaPrompt.ask("Temporary schema", default=temp_schema)
else:
temp_schema = TemporarySchemaPrompt.ask("Temporary schema (<database>.<schema>)")

float_tolerance = FloatPrompt.ask("Float tolerance", default=0.000001)

return TDsConfig(
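
Reviewer aid, not part of the diff: a minimal sketch of the default that the new `_get_temp_schema` helper derives, assuming `data_diff` is importable and substituting a mock for a real dbt project:

```python
# _get_temp_schema builds the prompt default from the dbt data-diff vars and
# normalizes its case per database type: upper for Snowflake, lower for the
# Postgres family and Redshift, unchanged otherwise.
from unittest.mock import Mock

from data_diff.cloud.data_source import _get_temp_schema

parser = Mock()
parser.get_datadiff_variables.return_value = {
    "prod_database": "Analytics",
    "prod_schema": "Tmp",
}

assert _get_temp_schema(dbt_parser=parser, db_type="snowflake") == "ANALYTICS.TMP"
assert _get_temp_schema(dbt_parser=parser, db_type="redshift") == "analytics.tmp"
assert _get_temp_schema(dbt_parser=parser, db_type="databricks") == "Analytics.Tmp"

# With either var missing the helper returns None, and create_ds_config falls
# back to the free-form "<database>.<schema>" prompt.
parser.get_datadiff_variables.return_value = {"prod_database": "Analytics"}
assert _get_temp_schema(dbt_parser=parser, db_type="snowflake") is None
```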
@@ -92,6 +112,37 @@ def _cast_value(value: str, type_: str) -> Union[bool, int, str]:
return value


def _get_data_from_bigquery_json(path: str):
with open(path, "r") as file:
return json.load(file)


def _align_dbt_cred_params_with_datafold_params(dbt_creds: dict) -> dict:
db_type = dbt_creds["type"]
if db_type == "bigquery":
method = dbt_creds["method"]
if method == "service-account":
data = _get_data_from_bigquery_json(path=dbt_creds["keyfile"])
dbt_creds["jsonKeyFile"] = json.dumps(data)
elif method == "service-account-json":
dbt_creds["jsonKeyFile"] = json.dumps(dbt_creds["keyfile_json"])
else:
rich.print(
f'[red]Cannot extract bigquery credentials from dbt_project.yml for "{method}" type. '
f"If you want to provide credentials via dbt_project.yml, "
f'please, use "service-account" or "service-account-json" '
f"(more in docs: https://docs.getdbt.com/reference/warehouse-setups/bigquery-setup). "
f"Otherwise, you can provide a path to a json key file or a json key file data as an input."
)
dbt_creds["projectId"] = dbt_creds["project"]
elif db_type == "snowflake":
dbt_creds["default_db"] = dbt_creds["database"]
elif db_type == "databricks":
dbt_creds["http_password"] = dbt_creds["token"]
dbt_creds["database"] = dbt_creds.get("catalog")
return dbt_creds


def _parse_ds_credentials(
ds_config: TCloudApiDataSourceConfigSchema, only_basic_settings: bool = True, dbt_parser: Optional[DbtParser] = None
):
@@ -101,6 +152,7 @@
use_dbt_data = Confirm.ask("Would you like to extract database credentials from dbt profiles.yml?")
try:
creds = dbt_parser.get_connection_creds()[0]
creds = _align_dbt_cred_params_with_datafold_params(dbt_creds=creds)
except Exception as e:
rich.print(f"[red]Cannot parse database credentials from dbt profiles.yml. Reason: {e}")

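
Reviewer note, not part of the diff: a short sketch of the key renames `_align_dbt_cred_params_with_datafold_params` performs for the two simple back ends, again assuming `data_diff` is importable:

```python
from data_diff.cloud.data_source import _align_dbt_cred_params_with_datafold_params

# Snowflake: Datafold expects "default_db" where dbt profiles say "database".
creds = _align_dbt_cred_params_with_datafold_params(
    dbt_creds={"type": "snowflake", "database": "analytics"}
)
assert creds["default_db"] == "analytics"

# Databricks: dbt's "token" and "catalog" map to "http_password" and "database".
creds = _align_dbt_cred_params_with_datafold_params(
    dbt_creds={"type": "databricks", "token": "secret", "catalog": "main"}
)
assert creds["http_password"] == "secret" and creds["database"] == "main"
```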
13 changes: 10 additions & 3 deletions data_diff/cloud/datafold_api.py
@@ -1,3 +1,4 @@
import base64
import dataclasses
import enum
import time
@@ -159,7 +160,7 @@ class TCloudDataSourceTestResult(pydantic.BaseModel):
class TCloudApiDataSourceTestResult(pydantic.BaseModel):
name: str
status: str
result: TCloudDataSourceTestResult
result: Optional[TCloudDataSourceTestResult]


@dataclasses.dataclass
@@ -191,7 +192,11 @@ def get_data_sources(self) -> List[TCloudApiDataSource]:
return [TCloudApiDataSource(**item) for item in rv.json()]

def create_data_source(self, config: TDsConfig) -> TCloudApiDataSource:
rv = self.make_post_request(url="api/v1/data_sources", payload=config.dict())
payload = config.dict()
if config.type == "bigquery":
json_string = payload["options"]["jsonKeyFile"].encode("utf-8")
payload["options"]["jsonKeyFile"] = base64.b64encode(json_string).decode("utf-8")
rv = self.make_post_request(url="api/v1/data_sources", payload=payload)
return TCloudApiDataSource(**rv.json())

def get_data_source_schema_config(
@@ -250,7 +255,9 @@ def check_data_source_test_results(self, job_id: int) -> List[TCloudApiDataSourc
status=item["result"]["code"].lower(),
message=item["result"]["message"],
outcome=item["result"]["outcome"],
),
)
if item["result"] is not None
else None,
)
for item in rv.json()["results"]
]
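
For context on the `create_data_source` change above: the BigQuery JSON key file is now base64-encoded before the POST. A self-contained sketch of the round trip (the key contents here are hypothetical):

```python
import base64
import json

# Hypothetical key-file contents; a real service-account key has more fields.
json_key_file = json.dumps({"type": "service_account", "project_id": "my-project"})

# What create_data_source now applies to payload["options"]["jsonKeyFile"]:
encoded = base64.b64encode(json_key_file.encode("utf-8")).decode("utf-8")

# The server side can recover the original string by reversing the encoding.
assert base64.b64decode(encoded).decode("utf-8") == json_key_file
```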
2 changes: 1 addition & 1 deletion data_diff/dbt_parser.py
@@ -145,7 +145,7 @@ def set_connection(self):
"role": credentials.get("role"),
"schema": credentials.get("schema"),
"insecure_mode": credentials.get("insecure_mode", False),
"client_session_keep_alive": credentials.get("client_session_keep_alive", False)
"client_session_keep_alive": credentials.get("client_session_keep_alive", False),
}
self.threads = credentials.get("threads")
self.requires_upper = True
121 changes: 101 additions & 20 deletions tests/cloud/test_data_source.py
@@ -1,10 +1,9 @@
import copy
from io import StringIO
import json
from pathlib import Path
from parameterized import parameterized
import unittest
from unittest.mock import MagicMock, Mock, patch
from unittest.mock import Mock, patch

from data_diff.cloud.datafold_api import (
TCloudApiDataSourceConfigSchema,
@@ -13,20 +12,19 @@
TCloudApiDataSourceTestResult,
TCloudDataSourceTestResult,
TDsConfig,
TestDataSourceStatus,
)
from data_diff.dbt_parser import DbtParser
from data_diff.cloud.data_source import (
TDataSourceTestStage,
TestDataSourceStatus,
create_ds_config,
_check_data_source_exists,
_get_temp_schema,
_test_data_source,
)


DATA_SOURCE_CONFIGS = [
TDsConfig(
DATA_SOURCE_CONFIGS = {
"snowflake": TDsConfig(
name="ds_name",
type="snowflake",
options={
@@ -40,7 +38,7 @@
float_tolerance=0.000001,
temp_schema="database.temp_schema",
),
TDsConfig(
"pg": TDsConfig(
name="ds_name",
type="pg",
options={
@@ -53,18 +51,18 @@
float_tolerance=0.000001,
temp_schema="database.temp_schema",
),
TDsConfig(
"bigquery": TDsConfig(
name="ds_name",
type="bigquery",
options={
"projectId": "project_id",
"jsonKeyFile": "some_string",
"jsonKeyFile": '{"key1": "value1"}',
"location": "US",
},
float_tolerance=0.000001,
temp_schema="database.temp_schema",
),
TDsConfig(
"databricks": TDsConfig(
name="ds_name",
type="databricks",
options={
@@ -76,7 +74,7 @@
float_tolerance=0.000001,
temp_schema="database.temp_schema",
),
TDsConfig(
"redshift": TDsConfig(
name="ds_name",
type="redshift",
options={
@@ -89,7 +87,7 @@
float_tolerance=0.000001,
temp_schema="database.temp_schema",
),
TDsConfig(
"postgres_aurora": TDsConfig(
name="ds_name",
type="postgres_aurora",
options={
@@ -102,7 +100,7 @@
float_tolerance=0.000001,
temp_schema="database.temp_schema",
),
TDsConfig(
"postgres_aws_rds": TDsConfig(
name="ds_name",
type="postgres_aws_rds",
options={
@@ -115,7 +113,7 @@
float_tolerance=0.000001,
temp_schema="database.temp_schema",
),
]
}


def format_data_source_config_test(testcase_func, param_num, param):
@@ -144,7 +142,23 @@ def setUp(self) -> None:
self.api.get_data_source_schema_config.return_value = self.data_source_schema
self.api.get_data_sources.return_value = self.data_sources

@parameterized.expand([(c,) for c in DATA_SOURCE_CONFIGS], name_func=format_data_source_config_test)
@parameterized.expand([(c,) for c in DATA_SOURCE_CONFIGS.values()], name_func=format_data_source_config_test)
@patch("data_diff.dbt_parser.DbtParser.__new__")
def test_get_temp_schema(self, config: TDsConfig, mock_dbt_parser):
diff_vars = {
"prod_database": "db",
"prod_schema": "schema",
}
mock_dbt_parser.get_datadiff_variables.return_value = diff_vars
temp_schema = f'{diff_vars["prod_database"]}.{diff_vars["prod_schema"]}'
if config.type == "snowflake":
temp_schema = temp_schema.upper()
elif config.type in {"pg", "postgres_aurora", "postgres_aws_rds", "redshift"}:
temp_schema = temp_schema.lower()

assert _get_temp_schema(dbt_parser=mock_dbt_parser, db_type=config.type) == temp_schema

@parameterized.expand([(c,) for c in DATA_SOURCE_CONFIGS.values()], name_func=format_data_source_config_test)
def test_create_ds_config(self, config: TDsConfig):
inputs = list(config.options.values()) + [config.temp_schema, config.float_tolerance]
with patch("rich.prompt.Console.input", side_effect=map(str, inputs)):
@@ -155,8 +169,8 @@ def test_create_ds_config(self, config: TDsConfig):
self.assertEqual(actual_config, config)

@patch("data_diff.dbt_parser.DbtParser.__new__")
def test_create_ds_config_from_dbt_profiles(self, mock_dbt_parser):
config = DATA_SOURCE_CONFIGS[0]
def test_create_snowflake_ds_config_from_dbt_profiles(self, mock_dbt_parser):
config = DATA_SOURCE_CONFIGS["snowflake"]
mock_dbt_parser.get_connection_creds.return_value = (config.options,)
with patch("rich.prompt.Console.input", side_effect=["y", config.temp_schema, str(config.float_tolerance)]):
actual_config = create_ds_config(
@@ -166,11 +180,78 @@ def test_create_ds_config_from_dbt_profiles(self, mock_dbt_parser):
)
self.assertEqual(actual_config, config)

@patch("data_diff.dbt_parser.DbtParser.__new__")
def test_create_bigquery_ds_config_dbt_oauth(self, mock_dbt_parser):
config = DATA_SOURCE_CONFIGS["bigquery"]
mock_dbt_parser.get_connection_creds.return_value = (config.options,)
with patch("rich.prompt.Console.input", side_effect=["y", config.temp_schema, str(config.float_tolerance)]):
actual_config = create_ds_config(
ds_config=self.db_type_data_source_schemas[config.type],
data_source_name=config.name,
dbt_parser=mock_dbt_parser,
)
self.assertEqual(actual_config, config)

@patch("data_diff.dbt_parser.DbtParser.__new__")
@patch("data_diff.cloud.data_source._get_data_from_bigquery_json")
def test_create_bigquery_ds_config_dbt_service_account(self, mock_get_data_from_bigquery_json, mock_dbt_parser):
config = DATA_SOURCE_CONFIGS["bigquery"]

mock_get_data_from_bigquery_json.return_value = json.loads(config.options["jsonKeyFile"])
mock_dbt_parser.get_connection_creds.return_value = (
{
"type": "bigquery",
"method": "service-account",
"project": config.options["projectId"],
"threads": 1,
"keyfile": "/some/path",
},
)

with patch(
"rich.prompt.Console.input",
side_effect=["y", config.options["location"], config.temp_schema, str(config.float_tolerance)],
):
actual_config = create_ds_config(
ds_config=self.db_type_data_source_schemas[config.type],
data_source_name=config.name,
dbt_parser=mock_dbt_parser,
)
self.assertEqual(actual_config, config)

@patch("data_diff.dbt_parser.DbtParser.__new__")
def test_create_bigquery_ds_config_dbt_service_account_json(self, mock_dbt_parser):
config = DATA_SOURCE_CONFIGS["bigquery"]

mock_dbt_parser.get_connection_creds.return_value = (
{
"type": "bigquery",
"method": "service-account-json",
"project": config.options["projectId"],
"threads": 1,
"keyfile_json": json.loads(config.options["jsonKeyFile"]),
},
)

with patch(
"rich.prompt.Console.input",
side_effect=["y", config.options["location"], config.temp_schema, str(config.float_tolerance)],
):
actual_config = create_ds_config(
ds_config=self.db_type_data_source_schemas[config.type],
data_source_name=config.name,
dbt_parser=mock_dbt_parser,
)
self.assertEqual(actual_config, config)

@patch("sys.stdout", new_callable=StringIO)
@patch("data_diff.dbt_parser.DbtParser.__new__")
def test_create_ds_config_from_dbt_profiles_one_param_passed_through_input(self, mock_dbt_parser, mock_stdout):
config = DATA_SOURCE_CONFIGS[0]
options = copy.copy(config.options)
def test_create_ds_snowflake_config_from_dbt_profiles_one_param_passed_through_input(
self, mock_dbt_parser, mock_stdout
):
config = DATA_SOURCE_CONFIGS["snowflake"]
options = {**config.options, "type": "snowflake"}
options["database"] = options.pop("default_db")
account = options.pop("account")
mock_dbt_parser.get_connection_creds.return_value = (options,)
with patch(