43 changes: 27 additions & 16 deletions airflow/providers/microsoft/azure/hooks/data_factory.py
@@ -36,6 +36,7 @@
from typing import Any, Callable, TypeVar, Union, cast

from asgiref.sync import sync_to_async
from azure.core.exceptions import ServiceRequestError
from azure.core.polling import LROPoller
from azure.identity import ClientSecretCredential, DefaultAzureCredential
from azure.identity.aio import (
@@ -214,6 +215,10 @@ def get_conn(self) -> DataFactoryManagementClient:

return self._conn

def refresh_conn(self) -> DataFactoryManagementClient:
self._conn = None
return self.get_conn()

@provide_targeted_factory
def get_factory(
self, resource_group_name: str | None = None, factory_name: str | None = None, **config: Any
@@ -812,6 +817,7 @@ def wait_for_pipeline_run_status(
resource_group_name=resource_group_name,
)
pipeline_run_status = self.get_pipeline_run_status(**pipeline_run_info)
executed_after_token_refresh = True

start_time = time.monotonic()

@@ -828,7 +834,14 @@
# Wait to check the status of the pipeline run based on the ``check_interval`` configured.
time.sleep(check_interval)

pipeline_run_status = self.get_pipeline_run_status(**pipeline_run_info)
try:
pipeline_run_status = self.get_pipeline_run_status(**pipeline_run_info)
executed_after_token_refresh = True
except ServiceRequestError:
if executed_after_token_refresh:
self.refresh_conn()
continue
raise

return pipeline_run_status in expected_statuses

@@ -1132,6 +1145,10 @@ async def get_async_conn(self) -> AsyncDataFactoryManagementClient:

return self._async_conn

async def refresh_conn(self) -> AsyncDataFactoryManagementClient:
self._conn = None
return await self.get_async_conn()

@provide_targeted_factory_async
async def get_pipeline_run(
self,
@@ -1149,11 +1166,8 @@ async def get_pipeline_run(
:param config: Extra parameters for the ADF client.
"""
client = await self.get_async_conn()
try:
pipeline_run = await client.pipeline_runs.get(resource_group_name, factory_name, run_id)
return pipeline_run
except Exception as e:
raise AirflowException(e)
pipeline_run = await client.pipeline_runs.get(resource_group_name, factory_name, run_id)
return pipeline_run

async def get_adf_pipeline_run_status(
self, run_id: str, resource_group_name: str | None = None, factory_name: str | None = None
@@ -1165,16 +1179,13 @@ async def get_adf_pipeline_run_status(
:param resource_group_name: The resource group name.
:param factory_name: The factory name.
"""
try:
pipeline_run = await self.get_pipeline_run(
run_id=run_id,
factory_name=factory_name,
resource_group_name=resource_group_name,
)
status: str = pipeline_run.status
return status
except Exception as e:
raise AirflowException(e)
pipeline_run = await self.get_pipeline_run(
run_id=run_id,
factory_name=factory_name,
resource_group_name=resource_group_name,
)
status: str = pipeline_run.status
return status

@provide_targeted_factory_async
async def cancel_pipeline_run(
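The new `refresh_conn` methods and the `ServiceRequestError` handling above give the hook a simple recovery path: when a status poll fails because the cached credential has expired, the client is rebuilt once and the poll is retried. The following is a minimal, self-contained sketch of that refresh-once-then-retry idea; the function and parameter names are illustrative and not part of the provider.

```python
# Illustrative sketch of the refresh-once-then-retry pattern; not provider code.
from azure.core.exceptions import ServiceRequestError


def poll_with_refresh(get_status, refresh_client, is_terminal, sleep):
    """Poll get_status(); on a ServiceRequestError, refresh the client once and retry."""
    succeeded_since_refresh = True  # plays the role of executed_after_token_refresh
    status = get_status()
    while not is_terminal(status):
        sleep()
        try:
            status = get_status()
            succeeded_since_refresh = True
        except ServiceRequestError:
            if succeeded_since_refresh:
                refresh_client()  # e.g. hook.refresh_conn()
                succeeded_since_refresh = False
                continue
            # The refreshed client failed too, so give up and propagate the error.
            raise
    return status
```

Only request-level failures are retried; any other exception still propagates immediately, which keeps genuine pipeline or configuration errors visible.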
126 changes: 82 additions & 44 deletions airflow/providers/microsoft/azure/triggers/data_factory.py
@@ -20,6 +20,8 @@
import time
from typing import Any, AsyncIterator

from azure.core.exceptions import ServiceRequestError

from airflow.providers.microsoft.azure.hooks.data_factory import (
AzureDataFactoryAsyncHook,
AzureDataFactoryPipelineRunStatus,
@@ -68,24 +70,41 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
async def run(self) -> AsyncIterator[TriggerEvent]:
"""Make async connection to Azure Data Factory, polls for the pipeline run status."""
hook = AzureDataFactoryAsyncHook(azure_data_factory_conn_id=self.azure_data_factory_conn_id)
executed_after_token_refresh = False
try:
while True:
pipeline_status = await hook.get_adf_pipeline_run_status(
run_id=self.run_id,
resource_group_name=self.resource_group_name,
factory_name=self.factory_name,
)
if pipeline_status == AzureDataFactoryPipelineRunStatus.FAILED:
yield TriggerEvent(
{"status": "error", "message": f"Pipeline run {self.run_id} has Failed."}
try:
pipeline_status = await hook.get_adf_pipeline_run_status(
run_id=self.run_id,
resource_group_name=self.resource_group_name,
factory_name=self.factory_name,
)
elif pipeline_status == AzureDataFactoryPipelineRunStatus.CANCELLED:
msg = f"Pipeline run {self.run_id} has been Cancelled."
yield TriggerEvent({"status": "error", "message": msg})
elif pipeline_status == AzureDataFactoryPipelineRunStatus.SUCCEEDED:
msg = f"Pipeline run {self.run_id} has been Succeeded."
yield TriggerEvent({"status": "success", "message": msg})
await asyncio.sleep(self.poke_interval)
executed_after_token_refresh = False
if pipeline_status == AzureDataFactoryPipelineRunStatus.FAILED:
yield TriggerEvent(
{"status": "error", "message": f"Pipeline run {self.run_id} has Failed."}
)
return
elif pipeline_status == AzureDataFactoryPipelineRunStatus.CANCELLED:
msg = f"Pipeline run {self.run_id} has been Cancelled."
yield TriggerEvent({"status": "error", "message": msg})
return
elif pipeline_status == AzureDataFactoryPipelineRunStatus.SUCCEEDED:
msg = f"Pipeline run {self.run_id} has been Succeeded."
yield TriggerEvent({"status": "success", "message": msg})
return
await asyncio.sleep(self.poke_interval)
except ServiceRequestError:
# The connection might expire while polling a long-running pipeline.
# If a ServiceRequestError is caught, the connection is refreshed once
# and the status call is retried. If the retry fails as well,
# executed_after_token_refresh is still False and the exception
# is re-raised.
if executed_after_token_refresh:
await hook.refresh_conn()
executed_after_token_refresh = False
continue
raise
except Exception as e:
yield TriggerEvent({"status": "error", "message": str(e)})

@@ -147,33 +166,49 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
resource_group_name=self.resource_group_name,
factory_name=self.factory_name,
)
executed_after_token_refresh = True
if self.wait_for_termination:
while self.end_time > time.time():
pipeline_status = await hook.get_adf_pipeline_run_status(
run_id=self.run_id,
resource_group_name=self.resource_group_name,
factory_name=self.factory_name,
)
if pipeline_status in AzureDataFactoryPipelineRunStatus.FAILURE_STATES:
yield TriggerEvent(
{
"status": "error",
"message": f"The pipeline run {self.run_id} has {pipeline_status}.",
"run_id": self.run_id,
}
try:
pipeline_status = await hook.get_adf_pipeline_run_status(
run_id=self.run_id,
resource_group_name=self.resource_group_name,
factory_name=self.factory_name,
)
elif pipeline_status == AzureDataFactoryPipelineRunStatus.SUCCEEDED:
yield TriggerEvent(
{
"status": "success",
"message": f"The pipeline run {self.run_id} has {pipeline_status}.",
"run_id": self.run_id,
}
executed_after_token_refresh = True
if pipeline_status in AzureDataFactoryPipelineRunStatus.FAILURE_STATES:
yield TriggerEvent(
{
"status": "error",
"message": f"The pipeline run {self.run_id} has {pipeline_status}.",
"run_id": self.run_id,
}
)
return
elif pipeline_status == AzureDataFactoryPipelineRunStatus.SUCCEEDED:
yield TriggerEvent(
{
"status": "success",
"message": f"The pipeline run {self.run_id} has {pipeline_status}.",
"run_id": self.run_id,
}
)
return
self.log.info(
"Sleeping for %s. The pipeline state is %s.", self.check_interval, pipeline_status
)
self.log.info(
"Sleeping for %s. The pipeline state is %s.", self.check_interval, pipeline_status
)
await asyncio.sleep(self.check_interval)
await asyncio.sleep(self.check_interval)
except ServiceRequestError:
# The connection might expire while polling a long-running pipeline.
# If a ServiceRequestError is caught, the connection is refreshed once
# and the status call is retried. If the retry fails as well,
# executed_after_token_refresh is still False and the exception
# is re-raised.
if executed_after_token_refresh:
await hook.refresh_conn()
executed_after_token_refresh = False
continue
raise

yield TriggerEvent(
{
Expand All @@ -192,10 +227,13 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
)
except Exception as e:
if self.run_id:
await hook.cancel_pipeline_run(
run_id=self.run_id,
resource_group_name=self.resource_group_name,
factory_name=self.factory_name,
)
self.log.info("Unexpected error %s caught. Cancel pipeline run %s", str(e), self.run_id)
try:
await hook.cancel_pipeline_run(
run_id=self.run_id,
resource_group_name=self.resource_group_name,
factory_name=self.factory_name,
)
self.log.info("Unexpected error %s caught. Cancel pipeline run %s", str(e), self.run_id)
except Exception as err:
yield TriggerEvent({"status": "error", "message": str(err), "run_id": self.run_id})
yield TriggerEvent({"status": "error", "message": str(e), "run_id": self.run_id})
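For orientation, the events these triggers yield are plain dictionaries with a `status`, a `message`, and (for the run trigger) a `run_id`. Below is a hedged sketch of how a deferrable operator's completion callback might consume them; it is written as a free function with illustrative names and is not the provider's actual operator code.

```python
# Illustrative sketch; not the provider's operator code.
from __future__ import annotations

from typing import Any

from airflow.exceptions import AirflowException


def handle_adf_trigger_event(context: dict[str, Any], event: dict[str, str] | None) -> str | None:
    """Fail on an "error" event; otherwise return the run_id, if one was sent."""
    # context is part of Airflow's usual callback signature but is unused here.
    if event is None:
        raise AirflowException("Trigger returned no event.")
    if event.get("status") == "error":
        raise AirflowException(event["message"])
    return event.get("run_id")
```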
36 changes: 17 additions & 19 deletions tests/providers/microsoft/azure/hooks/test_azure_data_factory.py
@@ -720,6 +720,14 @@ def test_backcompat_prefix_both_prefers_short(mock_connect):
mock_connect.return_value.factories.delete.assert_called_with("non-prefixed", "n/a")


def test_refresh_conn(hook):
"""Test refresh_conn method _conn is reset and get_conn is called"""
with patch.object(hook, "get_conn") as mock_get_conn:
hook.refresh_conn()
assert not hook._conn
assert mock_get_conn.called


class TestAzureDataFactoryAsyncHook:
@pytest.mark.asyncio
@mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_async_conn")
@@ -780,16 +788,6 @@ async def test_get_adf_pipeline_run_status_cancelled(self, mock_get_pipeline_run
response = await hook.get_adf_pipeline_run_status(RUN_ID, RESOURCE_GROUP_NAME, DATAFACTORY_NAME)
assert response == mock_status

@pytest.mark.asyncio
@mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_async_conn")
@mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_pipeline_run")
async def test_get_adf_pipeline_run_status_exception(self, mock_get_pipeline_run, mock_conn):
"""Test get_adf_pipeline_run_status function with exception"""
mock_get_pipeline_run.side_effect = Exception("Test exception")
hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID)
with pytest.raises(AirflowException):
await hook.get_adf_pipeline_run_status(RUN_ID, RESOURCE_GROUP_NAME, DATAFACTORY_NAME)

@pytest.mark.asyncio
@mock.patch("azure.mgmt.datafactory.models._models_py3.PipelineRun")
@mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_connection")
@@ -810,15 +808,6 @@ async def test_get_pipeline_run_exception_without_resource(
with pytest.raises(AirflowException):
await hook.get_pipeline_run(RUN_ID, None, DATAFACTORY_NAME)

@pytest.mark.asyncio
@mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_async_conn")
async def test_get_pipeline_run_exception(self, mock_conn):
"""Test get_pipeline_run function with exception"""
mock_conn.return_value.pipeline_runs.get.side_effect = Exception("Test exception")
hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID)
with pytest.raises(AirflowException):
await hook.get_pipeline_run(RUN_ID, RESOURCE_GROUP_NAME, DATAFACTORY_NAME)

@pytest.mark.asyncio
@mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_connection")
async def test_get_async_conn(self, mock_connection):
@@ -958,3 +947,12 @@ def test_get_field_non_prefixed_extras(self):
assert get_field(extras, "factory_name", strict=True) == DATAFACTORY_NAME
with pytest.raises(KeyError):
get_field(extras, "non-existent-field", strict=True)

@pytest.mark.asyncio
@mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_async_conn")
async def test_refresh_conn(self, mock_get_async_conn):
"""Test refresh_conn method _conn is reset and get_async_conn is called"""
hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID)
await hook.refresh_conn()
assert not hook._conn
assert mock_get_async_conn.called
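The new tests above only assert that `refresh_conn` resets the cached client. A complementary test that drives the retry branch of `wait_for_pipeline_run_status` could look roughly like the sketch below; it is not part of this change, and the keyword arguments (`expected_statuses`, `check_interval`, `timeout`) are assumed from the hook's public API.

```python
# Illustrative sketch; not part of this change. Keyword argument names are assumed.
from unittest import mock

from azure.core.exceptions import ServiceRequestError

from airflow.providers.microsoft.azure.hooks.data_factory import AzureDataFactoryHook


def test_wait_for_pipeline_run_status_refreshes_conn_once():
    hook = AzureDataFactoryHook(azure_data_factory_conn_id="azure_data_factory_default")
    # First poll succeeds, the second hits an expired token, the third succeeds again.
    polls = ["InProgress", ServiceRequestError("token expired"), "Succeeded"]
    with mock.patch.object(hook, "get_pipeline_run_status", side_effect=polls), mock.patch.object(
        hook, "refresh_conn"
    ) as mock_refresh:
        assert hook.wait_for_pipeline_run_status(
            run_id="run_id",
            expected_statuses="Succeeded",
            resource_group_name="rg-name",
            factory_name="factory-name",
            check_interval=0,
            timeout=10,
        )
    mock_refresh.assert_called_once()
```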
Changes to the Azure Data Factory trigger tests
@@ -163,10 +163,14 @@ async def test_adf_pipeline_run_status_sensors_trigger_cancelled(
assert TriggerEvent({"status": "error", "message": mock_message}) == actual

@pytest.mark.asyncio
@mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.refresh_conn")
@mock.patch(f"{MODULE}.hooks.data_factory.AzureDataFactoryAsyncHook.get_adf_pipeline_run_status")
async def test_adf_pipeline_run_status_sensors_trigger_exception(self, mock_data_factory):
async def test_adf_pipeline_run_status_sensors_trigger_exception(
self, mock_data_factory, mock_refresh_token
):
"""Test EMR container sensors with raise exception"""
mock_data_factory.side_effect = Exception("Test exception")
mock_refresh_token.side_effect = Exception("Test exception")

task = [i async for i in self.TRIGGER.run()]
assert len(task) == 1