Skip to content

Commit

Permalink
Enable dbt Cloud provider to interact with single tenant instances (#…
Browse files Browse the repository at this point in the history
…24264)

* Enable provider to interact with single tenant

* Define single tenant arg on Operator

* Add test for single tenant endpoint

* Enable provider to interact with single tenant

* Define single tenant arg on Operator

* Add test for single tenant endpoint

* Code linting from black

* Code linting from black

* Pass tenant to dbtCloudHook in DbtCloudGetJobRunArtifactOperator class

* Make Tenant a connection-level setting

* Remove tenant arg from Operator

* Make tenant connection-level param that defaults to 'cloud'

* Remove tenant param from sensor

* Remove leftover param string from hook

* Update airflow/providers/dbt/cloud/hooks/dbt.py

Co-authored-by: Josh Fell <48934154+josh-fell@users.noreply.github.com>

* Parameterize test_init_hook to test single and multi tenant connections

* Integrate test simplification suggestion

* Add connection to TestDbtCloudJobRunSesnor

Co-authored-by: Josh Fell <48934154+josh-fell@users.noreply.github.com>
  • Loading branch information
epapineau and josh-fell authored Jun 6, 2022
1 parent 98b4e48 commit 7498fba
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 7 deletions.
9 changes: 6 additions & 3 deletions airflow/providers/dbt/cloud/hooks/dbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,14 +144,17 @@ class DbtCloudHook(HttpHook):
def get_ui_field_behaviour() -> Dict[str, Any]:
"""Builds custom field behavior for the dbt Cloud connection form in the Airflow UI."""
return {
"hidden_fields": ["host", "port", "schema", "extra"],
"relabeling": {"login": "Account ID", "password": "API Token"},
"hidden_fields": ["host", "port", "extra"],
"relabeling": {"login": "Account ID", "password": "API Token", "schema": "Tenant"},
"placeholders": {"schema": "Defaults to 'cloud'."},
}

def __init__(self, dbt_cloud_conn_id: str = default_conn_name, *args, **kwargs) -> None:
super().__init__(auth_type=TokenAuth)
self.dbt_cloud_conn_id = dbt_cloud_conn_id
self.base_url = "https://cloud.getdbt.com/api/v2/accounts/"
tenant = self.connection.schema if self.connection.schema else 'cloud'

self.base_url = f"https://{tenant}.getdbt.com/api/v2/accounts/"

@cached_property
def connection(self) -> Connection:
Expand Down
26 changes: 22 additions & 4 deletions tests/providers/dbt/cloud/hooks/test_dbt_cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,17 @@

ACCOUNT_ID_CONN = "account_id_conn"
NO_ACCOUNT_ID_CONN = "no_account_id_conn"
SINGLE_TENANT_CONN = "single_tenant_conn"
DEFAULT_ACCOUNT_ID = 11111
ACCOUNT_ID = 22222
SINGLE_TENANT_SCHEMA = "single.tenant"
TOKEN = "token"
PROJECT_ID = 33333
JOB_ID = 4444
RUN_ID = 5555

BASE_URL = "https://cloud.getdbt.com/api/v2/accounts/"
SINGLE_TENANT_URL = "https://single.tenant.getdbt.com/api/v2/accounts/"


class TestDbtCloudJobRunStatus:
Expand Down Expand Up @@ -119,15 +122,30 @@ def setup_class(self):
password=TOKEN,
)

# Connection with `schema` parameter set
schema_conn = Connection(
conn_id=SINGLE_TENANT_CONN,
conn_type=DbtCloudHook.conn_type,
login=DEFAULT_ACCOUNT_ID,
password=TOKEN,
schema=SINGLE_TENANT_SCHEMA,
)

db.merge_conn(account_id_conn)
db.merge_conn(no_account_id_conn)
db.merge_conn(schema_conn)

def test_init_hook(self):
hook = DbtCloudHook()
assert hook.dbt_cloud_conn_id == "dbt_cloud_default"
assert hook.base_url == BASE_URL
@pytest.mark.parametrize(
argnames="conn_id, url",
argvalues=[(ACCOUNT_ID_CONN, BASE_URL), (SINGLE_TENANT_CONN, SINGLE_TENANT_URL)],
ids=["multi-tenant", "single-tenant"],
)
def test_init_hook(self, conn_id, url):
hook = DbtCloudHook(conn_id)
assert hook.auth_type == TokenAuth
assert hook.method == "POST"
assert hook.dbt_cloud_conn_id == conn_id
assert hook.base_url == url

@pytest.mark.parametrize(
argnames="conn_id, account_id",
Expand Down
8 changes: 8 additions & 0 deletions tests/providers/dbt/cloud/sensors/test_dbt_cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,14 @@

import pytest

from airflow.models.connection import Connection
from airflow.providers.dbt.cloud.hooks.dbt import DbtCloudHook, DbtCloudJobRunException, DbtCloudJobRunStatus
from airflow.providers.dbt.cloud.sensors.dbt import DbtCloudJobRunSensor
from airflow.utils import db

ACCOUNT_ID = 11111
RUN_ID = 5555
TOKEN = "token"


class TestDbtCloudJobRunSensor:
Expand All @@ -37,6 +40,11 @@ def setup_class(self):
poke_interval=15,
)

# Connection
conn = Connection(conn_id="dbt", conn_type=DbtCloudHook.conn_type, login=ACCOUNT_ID, password=TOKEN)

db.merge_conn(conn)

def test_init(self):
assert self.sensor.dbt_cloud_conn_id == "dbt"
assert self.sensor.run_id == RUN_ID
Expand Down

0 comments on commit 7498fba

Please sign in to comment.