Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

--dbt add support for BQ service-account #609

Merged
merged 3 commits into from
Jun 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions data_diff/dbt_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from dbt.config.renderer import ProfileRenderer

from data_diff.errors import (
DataDiffDbtBigQueryOauthOnlyError,
DataDiffDbtBigQueryUnsupportedMethodError,
DataDiffDbtConnectionNotImplementedError,
DataDiffDbtCoreNoRunnerError,
DataDiffDbtNoSuccessfulModelsInRunError,
Expand Down Expand Up @@ -319,17 +319,25 @@ def set_connection(self):
else:
raise DataDiffDbtSnowflakeSetConnectionError("Snowflake: unsupported auth method")
elif conn_type == "bigquery":
supported_methods = ["oauth", "service-account"]
method = credentials.get("method")
# there are many connection types https://docs.getdbt.com/reference/warehouse-setups/bigquery-setup#oauth-via-gcloud
# this assumes that the user is auth'd via `gcloud auth application-default login`
if method is None or method != "oauth":
raise DataDiffDbtBigQueryOauthOnlyError("Oauth is the current method supported for Big Query.")
if method not in supported_methods:
raise DataDiffDbtBigQueryUnsupportedMethodError(
f"Method: {method} is not in the current methods supported for Big Query ({supported_methods})."
)

conn_info = {
"driver": conn_type,
"project": credentials.get("project"),
"dataset": credentials.get("dataset"),
}

self.threads = credentials.get("threads")
if method == supported_methods[1]:
conn_info["keyfile"] = credentials.get("keyfile")

elif conn_type == "duckdb":
conn_info = {
"driver": conn_type,
Expand Down
4 changes: 2 additions & 2 deletions data_diff/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ class DataDiffDbtSnowflakeSetConnectionError(Exception):
"Raised when a dbt snowflake profile has unexpected values."


class DataDiffDbtBigQueryOauthOnlyError(Exception):
"Raised when trying to use a method other than oauth with BigQuery."
class DataDiffDbtBigQueryUnsupportedMethodError(Exception):
"Raised when trying to use an unsupported connection with BigQuery."


class DataDiffDbtRedshiftPasswordOnlyError(Exception):
Expand Down
23 changes: 18 additions & 5 deletions data_diff/sqeleton/databases/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,18 @@ def import_bigquery():
return bigquery


def import_bigquery_service_account():
from google.oauth2 import service_account

return service_account


class Mixin_MD5(AbstractMixin_MD5):
def md5_as_int(self, s: str) -> str:
return f"cast(cast( ('0x' || substr(TO_HEX(md5({s})), 18)) as int64) as numeric)"


class Mixin_NormalizeValue(AbstractMixin_NormalizeValue):

def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
if coltype.rounds:
timestamp = f"timestamp_micros(cast(round(unix_micros(cast({value} as timestamp))/1000000, {coltype.precision})*1000000 as int))"
Expand Down Expand Up @@ -144,8 +149,8 @@ class Dialect(BaseDialect, Mixin_Schema):
"BOOL": Boolean,
"JSON": JSON,
}
TYPE_ARRAY_RE = re.compile(r'ARRAY<(.+)>')
TYPE_STRUCT_RE = re.compile(r'STRUCT<(.+)>')
TYPE_ARRAY_RE = re.compile(r"ARRAY<(.+)>")
TYPE_STRUCT_RE = re.compile(r"STRUCT<(.+)>")
MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_TimeTravel, Mixin_RandomSample}

def random(self) -> str:
Expand Down Expand Up @@ -173,7 +178,6 @@ def parse_type(
) -> ColType:
col_type = super().parse_type(table_path, col_name, type_repr, *args, **kwargs)
if isinstance(col_type, UnknownColType):

m = self.TYPE_ARRAY_RE.fullmatch(type_repr)
if m:
item_type = self.parse_type(table_path, col_name, m.group(1), *args, **kwargs)
Expand Down Expand Up @@ -207,9 +211,18 @@ class BigQuery(Database):
dialect = Dialect()

def __init__(self, project, *, dataset, **kw):
credentials = None
bigquery = import_bigquery()

self._client = bigquery.Client(project, **kw)
keyfile = kw.pop("keyfile", None)
if keyfile:
bigquery_service_account = import_bigquery_service_account()
credentials = bigquery_service_account.Credentials.from_service_account_file(
keyfile,
scopes=["https://www.googleapis.com/auth/cloud-platform"],
)

self._client = bigquery.Client(project=project, credentials=credentials, **kw)
self.project = project
self.dataset = dataset

Expand Down
29 changes: 24 additions & 5 deletions tests/test_dbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from data_diff.cloud.datafold_api import TCloudApiOrgMeta
from data_diff.diff_tables import Algorithm
from data_diff.errors import (
DataDiffDbtBigQueryUnsupportedMethodError,
DataDiffCustomSchemaNoConfigError,
DataDiffDbtBigQueryOauthOnlyError,
DataDiffDbtConnectionNotImplementedError,
DataDiffDbtCoreNoRunnerError,
DataDiffDbtNoSuccessfulModelsInRunError,
Expand Down Expand Up @@ -271,7 +271,7 @@ def test_set_connection_snowflake_key_and_password(self):

self.assertNotIsInstance(mock_self.connection, dict)

def test_set_connection_bigquery_success(self):
def test_set_connection_bigquery_oauth(self):
expected_driver = "bigquery"
expected_credentials = {
"method": "oauth",
Expand All @@ -288,17 +288,36 @@ def test_set_connection_bigquery_success(self):
self.assertEqual(mock_self.connection.get("project"), expected_credentials["project"])
self.assertEqual(mock_self.connection.get("dataset"), expected_credentials["dataset"])

def test_set_connection_bigquery_not_oauth(self):
def test_set_connection_bigquery_svc_account(self):
expected_driver = "bigquery"
expected_credentials = {
"method": "not_oauth",
"method": "service-account",
"project": "a_project",
"dataset": "a_dataset",
"keyfile": "/some/path",
}
mock_self = Mock()
mock_self.get_connection_creds.return_value = (expected_credentials, expected_driver)

DbtParser.set_connection(mock_self)

self.assertIsInstance(mock_self.connection, dict)
self.assertEqual(mock_self.connection.get("driver"), expected_driver)
self.assertEqual(mock_self.connection.get("project"), expected_credentials["project"])
self.assertEqual(mock_self.connection.get("dataset"), expected_credentials["dataset"])
self.assertEqual(mock_self.connection.get("keyfile"), expected_credentials["keyfile"])

def test_set_connection_bigquery_not_supported(self):
expected_driver = "bigquery"
expected_credentials = {
"method": "not_supported",
"project": "a_project",
"dataset": "a_dataset",
}

mock_self = Mock()
mock_self.get_connection_creds.return_value = (expected_credentials, expected_driver)
with self.assertRaises(DataDiffDbtBigQueryOauthOnlyError):
with self.assertRaises(DataDiffDbtBigQueryUnsupportedMethodError):
DbtParser.set_connection(mock_self)

self.assertNotIsInstance(mock_self.connection, dict)
Expand Down