Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 31 additions & 16 deletions airflow/providers/microsoft/azure/hooks/data_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,17 @@ def bind_argument(arg, default_key):
if arg not in bound_args.arguments or bound_args.arguments[arg] is None:
self = args[0]
conn = self.get_connection(self.conn_id)
default_value = conn.extra_dejson.get(default_key)
extras = conn.extra_dejson
default_value = extras.get(default_key) or extras.get(
f"extra__azure_data_factory__{default_key}"
)
if not default_value:
raise AirflowException("Could not determine the targeted data factory.")

bound_args.arguments[arg] = conn.extra_dejson[default_key]
bound_args.arguments[arg] = default_value

bind_argument("resource_group_name", "extra__azure_data_factory__resource_group_name")
bind_argument("factory_name", "extra__azure_data_factory__factory_name")
bind_argument("resource_group_name", "resource_group_name")
bind_argument("factory_name", "factory_name")

return func(*bound_args.args, **bound_args.kwargs)

Expand Down Expand Up @@ -113,6 +116,23 @@ class AzureDataFactoryPipelineRunException(AirflowException):
"""An exception that indicates a pipeline run failed to complete."""


def get_field(extras: dict, field_name: str, strict: bool = False):
    """Return a connection-extra value, preferring the short (unprefixed) key.

    Looks up ``field_name`` directly in *extras* first; if absent, falls back to
    the legacy ``extra__azure_data_factory__``-prefixed key for backward
    compatibility. Falsy stored values (e.g. empty string) are normalized to
    ``None``.

    :param extras: the connection's ``extra_dejson`` mapping.
    :param field_name: short field name; must NOT carry the legacy prefix.
    :param strict: when True, raise ``KeyError`` if neither key is present.
    :raises ValueError: if ``field_name`` is already prefixed with ``extra__``.
    :raises KeyError: in strict mode, when the field is missing entirely.
    """
    backcompat_prefix = "extra__azure_data_factory__"
    if field_name.startswith("extra__"):
        raise ValueError(
            f"Got prefixed name {field_name}; please remove the '{backcompat_prefix}' prefix "
            "when using this method."
        )
    # Short name wins over the prefixed one when both are present.
    for candidate in (field_name, f"{backcompat_prefix}{field_name}"):
        if candidate in extras:
            return extras[candidate] or None
    if strict:
        raise KeyError(f"Field {field_name} not found in extras")
    return None


class AzureDataFactoryHook(BaseHook):
"""
A hook to interact with Azure Data Factory.
Expand All @@ -133,18 +153,12 @@ def get_connection_form_widgets() -> dict[str, Any]:
from wtforms import StringField

return {
"extra__azure_data_factory__tenantId": StringField(
lazy_gettext("Tenant ID"), widget=BS3TextFieldWidget()
),
"extra__azure_data_factory__subscriptionId": StringField(
lazy_gettext("Subscription ID"), widget=BS3TextFieldWidget()
),
"extra__azure_data_factory__resource_group_name": StringField(
"tenantId": StringField(lazy_gettext("Tenant ID"), widget=BS3TextFieldWidget()),
"subscriptionId": StringField(lazy_gettext("Subscription ID"), widget=BS3TextFieldWidget()),
"resource_group_name": StringField(
lazy_gettext("Resource Group Name"), widget=BS3TextFieldWidget()
),
"extra__azure_data_factory__factory_name": StringField(
lazy_gettext("Factory Name"), widget=BS3TextFieldWidget()
),
"factory_name": StringField(lazy_gettext("Factory Name"), widget=BS3TextFieldWidget()),
}

@staticmethod
Expand All @@ -168,10 +182,11 @@ def get_conn(self) -> DataFactoryManagementClient:
return self._conn

conn = self.get_connection(self.conn_id)
tenant = conn.extra_dejson.get("extra__azure_data_factory__tenantId")
extras = conn.extra_dejson
tenant = get_field(extras, "tenantId")

try:
subscription_id = conn.extra_dejson["extra__azure_data_factory__subscriptionId"]
subscription_id = get_field(extras, "subscriptionId", strict=True)
except KeyError:
raise ValueError("A Subscription ID is required to connect to Azure Data Factory.")

Expand Down
18 changes: 9 additions & 9 deletions airflow/providers/microsoft/azure/operators/data_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
AzureDataFactoryHook,
AzureDataFactoryPipelineRunException,
AzureDataFactoryPipelineRunStatus,
get_field,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -53,17 +54,16 @@ def get_link(
task_id=operator.task_id,
execution_date=dttm,
)

conn = BaseHook.get_connection(operator.azure_data_factory_conn_id)
subscription_id = conn.extra_dejson["extra__azure_data_factory__subscriptionId"]
conn_id = operator.azure_data_factory_conn_id
conn = BaseHook.get_connection(conn_id)
extras = conn.extra_dejson
subscription_id = get_field(extras, "subscriptionId")
if not subscription_id:
raise KeyError(f"Param subscriptionId not found in conn_id '{conn_id}'")
# Both Resource Group Name and Factory Name can either be declared in the Azure Data Factory
# connection or passed directly to the operator.
resource_group_name = operator.resource_group_name or conn.extra_dejson.get(
"extra__azure_data_factory__resource_group_name"
)
factory_name = operator.factory_name or conn.extra_dejson.get(
"extra__azure_data_factory__factory_name"
)
resource_group_name = operator.resource_group_name or get_field(extras, "resource_group_name")
factory_name = operator.factory_name or get_field(extras, "factory_name")
url = (
f"https://adf.azure.com/en-us/monitoring/pipelineruns/{run_id}"
f"?factory=/subscriptions/{subscription_id}/"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,22 +58,22 @@ Tenant ID
Specify the Azure tenant ID used for the initial connection.
This is needed for *token credentials* authentication mechanism.
It can be left out to fall back on ``DefaultAzureCredential``.
Use the key ``extra__azure_data_factory__tenantId`` to pass in the tenant ID.
Use extra param ``tenantId`` to pass in the tenant ID.

Subscription ID
Specify the ID of the subscription used for the initial connection.
This is needed for all authentication mechanisms.
Use the key ``extra__azure_data_factory__subscriptionId`` to pass in the Azure subscription ID.
Use extra param ``subscriptionId`` to pass in the Azure subscription ID.

Factory Name (optional)
Specify the Azure Data Factory to interface with.
If not specified in the connection, this needs to be passed in directly to hooks, operators, and sensors.
Use the key ``extra__azure_data_factory__factory_name`` to pass in the factory name.
Use extra param ``factory_name`` to pass in the factory name.

Resource Group Name (optional)
Specify the Azure Resource Group Name under which the desired data factory resides.
If not specified in the connection, this needs to be passed in directly to hooks, operators, and sensors.
Use the key ``extra__azure_data_factory__resource_group_name`` to pass in the resource group name.
Use extra param ``resource_group_name`` to pass in the resource group name.


When specifying the connection in environment variable you should specify
Expand All @@ -86,8 +86,8 @@ Examples

.. code-block:: bash

export AIRFLOW_CONN_AZURE_DATA_FACTORY_DEFAULT='azure-data-factory://applicationid:serviceprincipalpassword@?extra__azure_data_factory__tenantId=tenant+id&extra__azure_data_factory__subscriptionId=subscription+id&extra__azure_data_factory__resource_group_name=group+name&extra__azure_data_factory__factory_name=factory+name'
export AIRFLOW_CONN_AZURE_DATA_FACTORY_DEFAULT='azure-data-factory://applicationid:serviceprincipalpassword@?tenantId=tenant+id&subscriptionId=subscription+id&resource_group_name=group+name&factory_name=factory+name'

.. code-block:: bash

export AIRFLOW_CONN_AZURE_DATA_FACTORY_DEFAULT='azure-data-factory://applicationid:serviceprincipalpassword@?extra__azure_data_factory__tenantId=tenant+id&extra__azure_data_factory__subscriptionId=subscription+id'
export AIRFLOW_CONN_AZURE_DATA_FACTORY_DEFAULT='azure-data-factory://applicationid:serviceprincipalpassword@?tenantId=tenant+id&subscriptionId=subscription+id'
87 changes: 71 additions & 16 deletions tests/providers/microsoft/azure/hooks/test_azure_data_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,13 @@
from __future__ import annotations

import json
import os
from unittest.mock import MagicMock, PropertyMock, patch

import pytest
from azure.identity import ClientSecretCredential, DefaultAzureCredential
from azure.mgmt.datafactory.models import FactoryListResponse
from pytest import fixture
from pytest import fixture, param

from airflow.exceptions import AirflowException
from airflow.models.connection import Connection
Expand Down Expand Up @@ -56,10 +57,10 @@ def setup_module():
password="clientSecret",
extra=json.dumps(
{
"extra__azure_data_factory__tenantId": "tenantId",
"extra__azure_data_factory__subscriptionId": "subscriptionId",
"extra__azure_data_factory__resource_group_name": DEFAULT_RESOURCE_GROUP,
"extra__azure_data_factory__factory_name": DEFAULT_FACTORY,
"tenantId": "tenantId",
"subscriptionId": "subscriptionId",
"resource_group_name": DEFAULT_RESOURCE_GROUP,
"factory_name": DEFAULT_FACTORY,
}
),
)
Expand All @@ -68,9 +69,9 @@ def setup_module():
conn_type="azure_data_factory",
extra=json.dumps(
{
"extra__azure_data_factory__subscriptionId": "subscriptionId",
"extra__azure_data_factory__resource_group_name": DEFAULT_RESOURCE_GROUP,
"extra__azure_data_factory__factory_name": DEFAULT_FACTORY,
"subscriptionId": "subscriptionId",
"resource_group_name": DEFAULT_RESOURCE_GROUP,
"factory_name": DEFAULT_FACTORY,
}
),
)
Expand All @@ -81,9 +82,9 @@ def setup_module():
password="clientSecret",
extra=json.dumps(
{
"extra__azure_data_factory__tenantId": "tenantId",
"extra__azure_data_factory__resource_group_name": DEFAULT_RESOURCE_GROUP,
"extra__azure_data_factory__factory_name": DEFAULT_FACTORY,
"tenantId": "tenantId",
"resource_group_name": DEFAULT_RESOURCE_GROUP,
"factory_name": DEFAULT_FACTORY,
}
),
)
Expand All @@ -94,9 +95,9 @@ def setup_module():
password="clientSecret",
extra=json.dumps(
{
"extra__azure_data_factory__subscriptionId": "subscriptionId",
"extra__azure_data_factory__resource_group_name": DEFAULT_RESOURCE_GROUP,
"extra__azure_data_factory__factory_name": DEFAULT_FACTORY,
"subscriptionId": "subscriptionId",
"resource_group_name": DEFAULT_RESOURCE_GROUP,
"factory_name": DEFAULT_FACTORY,
}
),
)
Expand Down Expand Up @@ -149,8 +150,8 @@ def echo(_, resource_group_name=None, factory_name=None):
assert provide_targeted_factory(echo)(hook, RESOURCE_GROUP, FACTORY) == (RESOURCE_GROUP, FACTORY)

conn.extra_dejson = {
"extra__azure_data_factory__resource_group_name": DEFAULT_RESOURCE_GROUP,
"extra__azure_data_factory__factory_name": DEFAULT_FACTORY,
"resource_group_name": DEFAULT_RESOURCE_GROUP,
"factory_name": DEFAULT_FACTORY,
}
assert provide_targeted_factory(echo)(hook) == (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY)
assert provide_targeted_factory(echo)(hook, RESOURCE_GROUP, None) == (RESOURCE_GROUP, DEFAULT_FACTORY)
Expand Down Expand Up @@ -653,3 +654,57 @@ def test_connection_failure_missing_tenant_id():

assert status is False
assert msg == "A Tenant ID is required when authenticating with Client ID and Secret."


@pytest.mark.parametrize(
    "uri",
    [
        param(
            "a://?extra__azure_data_factory__resource_group_name=abc"
            "&extra__azure_data_factory__factory_name=abc",
            id="prefix",
        ),
        param("a://?resource_group_name=abc&factory_name=abc", id="no-prefix"),
    ],
)
@patch("airflow.providers.microsoft.azure.hooks.data_factory.AzureDataFactoryHook.get_conn")
def test_provide_targeted_factory_backcompat_prefix_works(mock_connect, uri):
    """Defaults for resource group / factory resolve from connection extras whether
    the keys use the legacy ``extra__azure_data_factory__`` prefix or the short names."""
    # Inject the connection via environment variable so no metastore is needed.
    with patch.dict(os.environ, {"AIRFLOW_CONN_MY_CONN": uri}):
        hook = AzureDataFactoryHook("my_conn")
        # No explicit resource_group_name/factory_name: both must come from extras.
        hook.delete_factory()
        mock_connect.return_value.factories.delete.assert_called_with("abc", "abc")


@pytest.mark.parametrize(
    "uri",
    [
        param(
            "a://hi:yo@?extra__azure_data_factory__tenantId=ten"
            "&extra__azure_data_factory__subscriptionId=sub",
            id="prefix",
        ),
        param("a://hi:yo@?tenantId=ten&subscriptionId=sub", id="no-prefix"),
    ],
)
@patch("airflow.providers.microsoft.azure.hooks.data_factory.ClientSecretCredential")
@patch("airflow.providers.microsoft.azure.hooks.data_factory.AzureDataFactoryHook._create_client")
def test_get_conn_backcompat_prefix_works(mock_create, mock_cred, uri):
    """``get_conn`` reads tenantId/subscriptionId from extras under either the
    legacy-prefixed or the short key names and wires them into the credential/client."""
    with patch.dict(os.environ, {"AIRFLOW_CONN_MY_CONN": uri}):
        hook = AzureDataFactoryHook("my_conn")
        hook.get_conn()
        # login/password from the URI become client_id/client_secret; tenant from extras.
        mock_cred.assert_called_with(client_id="hi", client_secret="yo", tenant_id="ten")
        mock_create.assert_called_with(mock_cred.return_value, "sub")


@patch("airflow.providers.microsoft.azure.hooks.data_factory.AzureDataFactoryHook.get_conn")
def test_backcompat_prefix_both_prefers_short(mock_connect):
    """When a connection defines both the short and the legacy-prefixed key for the
    same field, the short (unprefixed) value must take precedence."""
    with patch.dict(
        os.environ,
        {
            "AIRFLOW_CONN_MY_CONN": "a://?resource_group_name=non-prefixed"
            "&extra__azure_data_factory__resource_group_name=prefixed"
        },
    ):
        hook = AzureDataFactoryHook("my_conn")
        # factory_name passed explicitly; only resource group comes from extras.
        hook.delete_factory(factory_name="n/a")
        mock_connect.return_value.factories.delete.assert_called_with("non-prefixed", "n/a")
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@
AZURE_DATA_FACTORY_CONN_ID = "azure_data_factory_test"
PIPELINE_NAME = "pipeline1"
CONN_EXTRAS = {
"extra__azure_data_factory__subscriptionId": SUBSCRIPTION_ID,
"extra__azure_data_factory__tenantId": "my-tenant-id",
"extra__azure_data_factory__resource_group_name": "my-resource-group-name-from-conn",
"extra__azure_data_factory__factory_name": "my-factory-name-from-conn",
"subscriptionId": SUBSCRIPTION_ID,
"tenantId": "my-tenant-id",
"resource_group_name": "my-resource-group-name-from-conn",
"factory_name": "my-factory-name-from-conn",
}
PIPELINE_RUN_RESPONSE = {"additional_properties": {}, "run_id": "run_id"}
EXPECTED_PIPELINE_RUN_OP_EXTRA_LINK = (
Expand Down Expand Up @@ -241,8 +241,8 @@ def test_run_pipeline_operator_link(self, resource_group, factory, create_task_i
)

conn = AzureDataFactoryHook.get_connection("azure_data_factory_test")
conn_resource_group_name = conn.extra_dejson["extra__azure_data_factory__resource_group_name"]
conn_factory_name = conn.extra_dejson["extra__azure_data_factory__factory_name"]
conn_resource_group_name = conn.extra_dejson["resource_group_name"]
conn_factory_name = conn.extra_dejson["factory_name"]

assert url == (
EXPECTED_PIPELINE_RUN_OP_EXTRA_LINK.format(
Expand Down