Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions airflow/providers/microsoft/azure/hooks/adx.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,8 @@
from azure.kusto.data.request import ClientRequestProperties, KustoClient, KustoConnectionStringBuilder
from azure.kusto.data.response import KustoResponseDataSetV2

from airflow.exceptions import AirflowException
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.hooks.base import BaseHook
from airflow.providers.microsoft.azure.utils import _ensure_prefixes


class AzureDataExplorerHook(BaseHook):
Expand Down Expand Up @@ -95,7 +94,6 @@ def get_connection_form_widgets() -> dict[str, Any]:
}

@staticmethod
@_ensure_prefixes(conn_type="azure_data_explorer")
def get_ui_field_behaviour() -> dict[str, Any]:
"""Returns custom field behaviour."""
return {
Expand Down Expand Up @@ -148,7 +146,13 @@ def get_required_param(name: str) -> str:
value = extras.get(name)
if value:
warn_if_collison(name, backcompat_key)
if not value:
if not value and extras.get(backcompat_key):
warnings.warn(
f"`{backcompat_key}` is deprecated in azure connection extra,"
f" please use `{name}` instead",
AirflowProviderDeprecationWarning,
stacklevel=2,
)
value = extras.get(backcompat_key)
if not value:
raise AirflowException(f"Required connection parameter is missing: `{name}`")
Expand Down
37 changes: 24 additions & 13 deletions airflow/providers/microsoft/azure/hooks/base_azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@
# under the License.
from __future__ import annotations

import warnings
from typing import Any

from azure.common.client_factory import get_client_from_auth_file, get_client_from_json_dict
from azure.common.credentials import ServicePrincipalCredentials

from airflow.exceptions import AirflowException
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.hooks.base import BaseHook


Expand Down Expand Up @@ -50,12 +51,8 @@ def get_connection_form_widgets() -> dict[str, Any]:
from wtforms import StringField

return {
"extra__azure__tenantId": StringField(
lazy_gettext("Azure Tenant ID"), widget=BS3TextFieldWidget()
),
"extra__azure__subscriptionId": StringField(
lazy_gettext("Azure Subscription ID"), widget=BS3TextFieldWidget()
),
"tenantId": StringField(lazy_gettext("Azure Tenant ID"), widget=BS3TextFieldWidget()),
"subscriptionId": StringField(lazy_gettext("Azure Subscription ID"), widget=BS3TextFieldWidget()),
}

@staticmethod
Expand All @@ -79,8 +76,8 @@ def get_ui_field_behaviour() -> dict[str, Any]:
),
"login": "client_id (token credentials auth)",
"password": "secret (token credentials auth)",
"extra__azure__tenantId": "tenantId (token credentials auth)",
"extra__azure__subscriptionId": "subscriptionId (token credentials auth)",
"tenantId": "tenantId (token credentials auth)",
"subscriptionId": "subscriptionId (token credentials auth)",
},
}

Expand All @@ -96,10 +93,24 @@ def get_conn(self) -> Any:
:return: the authenticated client.
"""
conn = self.get_connection(self.conn_id)
tenant = conn.extra_dejson.get("extra__azure__tenantId") or conn.extra_dejson.get("tenantId")
subscription_id = conn.extra_dejson.get("extra__azure__subscriptionId") or conn.extra_dejson.get(
"subscriptionId"
)
tenant = conn.extra_dejson.get("tenantId")
if not tenant and conn.extra_dejson.get("extra__azure__tenantId"):
warnings.warn(
"`extra__azure__tenantId` is deprecated in azure connection extra, "
"please use `tenantId` instead",
AirflowProviderDeprecationWarning,
stacklevel=2,
)
tenant = conn.extra_dejson.get("extra__azure__tenantId")
subscription_id = conn.extra_dejson.get("subscriptionId")
if not subscription_id and conn.extra_dejson.get("extra__azure__subscriptionId"):
warnings.warn(
"`extra__azure__subscriptionId` is deprecated in azure connection extra, "
"please use `subscriptionId` instead",
AirflowProviderDeprecationWarning,
stacklevel=2,
)
subscription_id = conn.extra_dejson.get("extra__azure__subscriptionId")

key_path = conn.extra_dejson.get("key_path")
if key_path:
Expand Down
3 changes: 1 addition & 2 deletions airflow/providers/microsoft/azure/hooks/container_volume.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from azure.mgmt.containerinstance.models import AzureFileVolume, Volume

from airflow.hooks.base import BaseHook
from airflow.providers.microsoft.azure.utils import _ensure_prefixes, get_field
from airflow.providers.microsoft.azure.utils import get_field


class AzureContainerVolumeHook(BaseHook):
Expand Down Expand Up @@ -65,7 +65,6 @@ def get_connection_form_widgets() -> dict[str, Any]:
}

@staticmethod
@_ensure_prefixes(conn_type="azure_container_volume")
def get_ui_field_behaviour() -> dict[str, Any]:
"""Returns custom field behaviour."""
return {
Expand Down
3 changes: 1 addition & 2 deletions airflow/providers/microsoft/azure/hooks/cosmos.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@

from airflow.exceptions import AirflowBadRequest
from airflow.hooks.base import BaseHook
from airflow.providers.microsoft.azure.utils import _ensure_prefixes, get_field
from airflow.providers.microsoft.azure.utils import get_field


class AzureCosmosDBHook(BaseHook):
Expand Down Expand Up @@ -71,7 +71,6 @@ def get_connection_form_widgets() -> dict[str, Any]:
}

@staticmethod
@_ensure_prefixes(conn_type="azure_cosmos") # todo: remove when min airflow version >= 2.5
def get_ui_field_behaviour() -> dict[str, Any]:
"""Returns custom field behaviour."""
return {
Expand Down
29 changes: 25 additions & 4 deletions airflow/providers/microsoft/azure/hooks/data_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

import inspect
import time
import warnings
from functools import wraps
from typing import Any, Callable, TypeVar, Union, cast

Expand All @@ -56,7 +57,7 @@
TriggerResource,
)

from airflow.exceptions import AirflowException
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.hooks.base import BaseHook
from airflow.typing_compat import TypedDict

Expand Down Expand Up @@ -85,9 +86,15 @@ def bind_argument(arg, default_key):
self = args[0]
conn = self.get_connection(self.conn_id)
extras = conn.extra_dejson
default_value = extras.get(default_key) or extras.get(
f"extra__azure_data_factory__{default_key}"
)
default_value = extras.get(default_key)
if not default_value and extras.get(f"extra__azure_data_factory__{default_key}"):
warnings.warn(
f"`extra__azure_data_factory__{default_key}` is deprecated in azure connection extra,"
f" please use `{default_key}` instead",
AirflowProviderDeprecationWarning,
stacklevel=2,
)
default_value = extras.get(f"extra__azure_data_factory__{default_key}")
if not default_value:
raise AirflowException("Could not determine the targeted data factory.")

Expand Down Expand Up @@ -139,6 +146,12 @@ def get_field(extras: dict, field_name: str, strict: bool = False):
return extras[field_name] or None
prefixed_name = f"{backcompat_prefix}{field_name}"
if prefixed_name in extras:
warnings.warn(
f"`{prefixed_name}` is deprecated in azure connection extra,"
f" please use `{field_name}` instead",
AirflowProviderDeprecationWarning,
stacklevel=2,
)
return extras[prefixed_name] or None
if strict:
raise KeyError(f"Field {field_name} not found in extras")
Expand Down Expand Up @@ -1086,6 +1099,14 @@ async def bind_argument(arg: Any, default_key: str) -> None:
default_value = extras.get(default_key) or extras.get(
f"extra__azure_data_factory__{default_key}"
)
if not default_value and extras.get(f"extra__azure_data_factory__{default_key}"):
warnings.warn(
f"`extra__azure_data_factory__{default_key}` is deprecated in azure connection extra,"
f" please use `{default_key}` instead",
AirflowProviderDeprecationWarning,
stacklevel=2,
)
default_value = extras.get(f"extra__azure_data_factory__{default_key}")
if not default_value:
raise AirflowException("Could not determine the targeted data factory.")

Expand Down
3 changes: 1 addition & 2 deletions airflow/providers/microsoft/azure/hooks/data_lake.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook
from airflow.providers.microsoft.azure.utils import _ensure_prefixes, get_field
from airflow.providers.microsoft.azure.utils import get_field


class AzureDataLakeHook(BaseHook):
Expand Down Expand Up @@ -73,7 +73,6 @@ def get_connection_form_widgets() -> dict[str, Any]:
}

@staticmethod
@_ensure_prefixes(conn_type="azure_data_lake")
def get_ui_field_behaviour() -> dict[str, Any]:
"""Returns custom field behaviour."""
return {
Expand Down
38 changes: 6 additions & 32 deletions airflow/providers/microsoft/azure/hooks/fileshare.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,44 +18,14 @@
from __future__ import annotations

import warnings
from functools import wraps
from typing import IO, Any

from azure.storage.file import File, FileService

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.hooks.base import BaseHook


def _ensure_prefixes(conn_type):
"""
Deprecated.

Remove when provider min airflow version >= 2.5.0 since this is handled by
provider manager from that version.
"""

def dec(func):
@wraps(func)
def inner():
field_behaviors = func()
conn_attrs = {"host", "schema", "login", "password", "port", "extra"}

def _ensure_prefix(field):
if field not in conn_attrs and not field.startswith("extra__"):
return f"extra__{conn_type}__{field}"
else:
return field

if "placeholders" in field_behaviors:
placeholders = field_behaviors["placeholders"]
field_behaviors["placeholders"] = {_ensure_prefix(k): v for k, v in placeholders.items()}
return field_behaviors

return inner

return dec


class AzureFileShareHook(BaseHook):
"""
Interacts with Azure FileShare Storage.
Expand Down Expand Up @@ -94,7 +64,6 @@ def get_connection_form_widgets() -> dict[str, Any]:
}

@staticmethod
@_ensure_prefixes(conn_type="azure_fileshare")
def get_ui_field_behaviour() -> dict[str, Any]:
"""Returns custom field behaviour."""
return {
Expand Down Expand Up @@ -138,6 +107,11 @@ def check_for_conflict(key):
check_for_conflict(key)
elif key.startswith(backcompat_prefix):
short_name = key[len(backcompat_prefix) :]
warnings.warn(
f"`{key}` is deprecated in azure connection extra please use `{short_name}` instead",
AirflowProviderDeprecationWarning,
stacklevel=2,
)
if short_name not in service_options: # prefer values provided with short name
service_options[short_name] = value
else:
Expand Down
32 changes: 0 additions & 32 deletions airflow/providers/microsoft/azure/hooks/wasb.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@

import logging
import os
from functools import wraps
from typing import Any, Union

from asgiref.sync import sync_to_async
Expand All @@ -51,36 +50,6 @@
AsyncCredentials = Union[AsyncClientSecretCredential, AsyncDefaultAzureCredential]


def _ensure_prefixes(conn_type):
"""
Deprecated.

Remove when provider min airflow version >= 2.5.0 since this is handled by
provider manager from that version.
"""

def dec(func):
@wraps(func)
def inner():
field_behaviors = func()
conn_attrs = {"host", "schema", "login", "password", "port", "extra"}

def _ensure_prefix(field):
if field not in conn_attrs and not field.startswith("extra__"):
return f"extra__{conn_type}__{field}"
else:
return field

if "placeholders" in field_behaviors:
placeholders = field_behaviors["placeholders"]
field_behaviors["placeholders"] = {_ensure_prefix(k): v for k, v in placeholders.items()}
return field_behaviors

return inner

return dec


class WasbHook(BaseHook):
"""
Interacts with Azure Blob Storage through the ``wasb://`` protocol.
Expand Down Expand Up @@ -124,7 +93,6 @@ def get_connection_form_widgets() -> dict[str, Any]:
}

@staticmethod
@_ensure_prefixes(conn_type="wasb")
def get_ui_field_behaviour() -> dict[str, Any]:
"""Returns custom field behaviour."""
return {
Expand Down
31 changes: 0 additions & 31 deletions airflow/providers/microsoft/azure/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,37 +18,6 @@
from __future__ import annotations

import warnings
from functools import wraps


def _ensure_prefixes(conn_type):
"""
Deprecated.

Remove when provider min airflow version >= 2.5.0 since this is handled by
provider manager from that version.
"""

def dec(func):
@wraps(func)
def inner():
field_behaviors = func()
conn_attrs = {"host", "schema", "login", "password", "port", "extra"}

def _ensure_prefix(field):
if field not in conn_attrs and not field.startswith("extra__"):
return f"extra__{conn_type}__{field}"
else:
return field

if "placeholders" in field_behaviors:
placeholders = field_behaviors["placeholders"]
field_behaviors["placeholders"] = {_ensure_prefix(k): v for k, v in placeholders.items()}
return field_behaviors

return inner

return dec


def get_field(*, conn_id: str, conn_type: str, extras: dict, field_name: str):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ There are four ways to connect to Azure Container Volume using Airflow.
i.e. add specific credentials (client_id, secret) and subscription id to the Airflow connection.
2. Use a `Connection String
<https://docs.microsoft.com/en-us/azure/data-explorer/kusto/api/connection-strings/storage>`_
i.e. add connection string to ``extra__azure_container_volume__connection_string`` in the Airflow connection.
i.e. add connection string to ``connection_string`` in the Airflow connection.

Only one authorization method can be used at a time. If you need to manage multiple credentials or keys then you should
configure multiple connections.
Expand All @@ -61,7 +61,7 @@ Extra (optional)
Specify the extra parameters (as json dictionary) that can be used in Azure connection.
The following parameters are all optional:

* ``extra__azure_container_volume__connection_string``: Connection string for use with connection string authentication.
* ``connection_string``: Connection string for use with connection string authentication.

When specifying the connection in environment variable you should specify
it using URI syntax.
Expand Down
Loading