Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Grab credentials from profiles.yml #491

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 55 additions & 8 deletions data_diff/cloud/data_source.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import enum
import time
from typing import List, Optional
from typing import List, Optional, Union, overload

import pydantic
import rich
from rich.table import Table
from rich.prompt import Confirm, Prompt, FloatPrompt, IntPrompt, InvalidResponse
from typing_extensions import Literal

from .datafold_api import (
DatafoldAPI,
Expand All @@ -14,6 +14,7 @@
TDsConfig,
TestDataSourceStatus,
)
from ..dbt_parser import DbtParser


UNKNOWN_VALUE = "unknown_value"
Expand Down Expand Up @@ -49,8 +50,12 @@ def _validate_temp_schema(temp_schema: str):
raise ValueError("Temporary schema should have a format <database>.<schema>")


def create_ds_config(ds_config: TCloudApiDataSourceConfigSchema, data_source_name: str) -> TDsConfig:
options = _parse_ds_credentials(ds_config=ds_config, only_basic_settings=True)
def create_ds_config(
ds_config: TCloudApiDataSourceConfigSchema,
data_source_name: str,
dbt_parser: Optional[DbtParser] = None,
) -> TDsConfig:
options = _parse_ds_credentials(ds_config=ds_config, only_basic_settings=True, dbt_parser=dbt_parser)

temp_schema = TemporarySchemaPrompt.ask("Temporary schema (<database>.<schema>)")
float_tolerance = FloatPrompt.ask("Float tolerance", default=0.000001)
Expand All @@ -64,7 +69,41 @@ def create_ds_config(ds_config: TCloudApiDataSourceConfigSchema, data_source_nam
)


def _parse_ds_credentials(ds_config: TCloudApiDataSourceConfigSchema, only_basic_settings: bool = True):
@overload
def _cast_value(value: str, type_: Literal["integer"]) -> int:
...


@overload
def _cast_value(value: str, type_: Literal["boolean"]) -> bool:
...


@overload
def _cast_value(value: str, type_: Literal["string"]) -> str:
...


def _cast_value(value: str, type_: str) -> Union[bool, int, str]:
if type_ == "integer":
return int(value)
elif type_ == "boolean":
return bool(value)
return value


def _parse_ds_credentials(
ds_config: TCloudApiDataSourceConfigSchema, only_basic_settings: bool = True, dbt_parser: Optional[DbtParser] = None
):
creds = {}
use_dbt_data = False
if dbt_parser is not None:
use_dbt_data = Confirm.ask("Would you like to extract database credentials from dbt profiles.yml?")
try:
creds = dbt_parser.get_connection_creds()[0]
except Exception as e:
rich.print(f"[red]Cannot parse database credentials from dbt profiles.yml. Reason: {e}")

ds_options = {}
basic_required_fields = set(ds_config.config_schema.required)
for param_name, param_data in ds_config.config_schema.properties.items():
Expand All @@ -83,6 +122,14 @@ def _parse_ds_credentials(ds_config: TCloudApiDataSourceConfigSchema, only_basic
if default_value != UNKNOWN_VALUE:
input_values["default"] = default_value

if use_dbt_data:
value = creds.get(param_name, UNKNOWN_VALUE)
if value == UNKNOWN_VALUE:
rich.print(f'[red]Cannot extract "{param_name}" from dbt profiles.yml. Please, type it manually')
else:
ds_options[param_name] = _cast_value(value, type_)
continue

if type_ == "integer":
value = IntPrompt.ask(**input_values)
elif type_ == "boolean":
Expand Down Expand Up @@ -177,7 +224,7 @@ def _render_data_source_test_results(test_results: List[TDataSourceTestStage]) -
rich.print(table)


def get_or_create_data_source(api: DatafoldAPI) -> int:
def get_or_create_data_source(api: DatafoldAPI, dbt_parser: Optional[DbtParser] = None) -> int:
ds_configs = api.get_data_source_schema_config()
data_sources = api.get_data_sources()

Expand All @@ -198,10 +245,10 @@ def get_or_create_data_source(api: DatafoldAPI) -> int:
_render_data_source(data_source=ds, title=f'Found existing data source for name "{ds.name}"')
use_existing_ds = Confirm.ask("Would you like to continue with the existing data source?")
if not use_existing_ds:
return get_or_create_data_source(api=api)
return get_or_create_data_source(api=api, dbt_parser=dbt_parser)
return ds.id

ds_config = create_ds_config(ds_config, ds_name)
ds_config = create_ds_config(ds_config=ds_config, data_source_name=ds_name, dbt_parser=dbt_parser)
ds = api.create_data_source(ds_config)
data_source_url = f"{api.host}/settings/integrations/dwh/{ds.type}/{ds.id}"
_render_data_source(data_source=ds, title=f"Created a new data source with ID = {ds.id} ({data_source_url})")
Expand Down
Loading