@@ -1121,6 +1121,20 @@ def __init__(self, azure_data_factory_conn_id: str = default_conn_name):
        self.conn_id = azure_data_factory_conn_id
        super().__init__(azure_data_factory_conn_id=azure_data_factory_conn_id)

    async def __aenter__(self):
        """Enter the async context manager; returns self for use in async with blocks."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Exit the async context manager; closes the async connection."""
        await self.close()

    async def close(self) -> None:
        """Close the async connection to Azure Data Factory."""
        if self._async_conn is not None:
            await self._async_conn.close()
            self._async_conn = None

    async def get_async_conn(self) -> AsyncDataFactoryManagementClient:
        """Get the async connection and connect to Azure Data Factory."""
        if self._async_conn is not None:
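A minimal usage sketch of the new context-manager support (illustrative, not part of this diff; the connection id and run identifiers below are placeholders):

    from airflow.providers.microsoft.azure.hooks.data_factory import AzureDataFactoryAsyncHook

    async def example() -> None:
        # Entering the block returns the hook; leaving it awaits close(),
        # which disposes of the cached async client even if polling raises.
        async with AzureDataFactoryAsyncHook(
            azure_data_factory_conn_id="azure_data_factory_default"
        ) as hook:
            status = await hook.get_adf_pipeline_run_status(
                run_id="00000000-0000-0000-0000-000000000000",
                resource_group_name="example-rg",
                factory_name="example-factory",
            )
            print(status)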
@@ -69,44 +69,50 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
            },
        )

    def _build_trigger_event(self, pipeline_status: str) -> TriggerEvent | None:
        """Build a TriggerEvent based on the pipeline status; return None if the status is not terminal."""
        if pipeline_status == AzureDataFactoryPipelineRunStatus.FAILED:
            return TriggerEvent({"status": "error", "message": f"Pipeline run {self.run_id} has Failed."})
        if pipeline_status == AzureDataFactoryPipelineRunStatus.CANCELLED:
            return TriggerEvent(
                {"status": "error", "message": f"Pipeline run {self.run_id} has been Cancelled."}
            )
        if pipeline_status == AzureDataFactoryPipelineRunStatus.SUCCEEDED:
            return TriggerEvent(
                {"status": "success", "message": f"Pipeline run {self.run_id} has been Succeeded."}
            )
        return None

    async def run(self) -> AsyncIterator[TriggerEvent]:
        """Make an async connection to Azure Data Factory and poll for the pipeline run status."""
        executed_after_token_refresh = True
        try:
            async with AzureDataFactoryAsyncHook(
                azure_data_factory_conn_id=self.azure_data_factory_conn_id
            ) as hook:
                while True:
                    try:
                        pipeline_status = await hook.get_adf_pipeline_run_status(
                            run_id=self.run_id,
                            resource_group_name=self.resource_group_name,
                            factory_name=self.factory_name,
                        )
                        executed_after_token_refresh = True
                        event = self._build_trigger_event(pipeline_status)
                        if event:
                            yield event
                            return
                        await asyncio.sleep(self.poke_interval)
                    except ServiceRequestError:
                        # The connection might expire during a long-running pipeline.
                        # When this exception is caught, refresh the connection once.
                        # If that does not fix the issue, executed_after_token_refresh
                        # is still False on the next failure and the exception is raised.
                        if not executed_after_token_refresh:
                            raise
                        await hook.refresh_conn()
                        executed_after_token_refresh = False
        except Exception as e:
            yield TriggerEvent({"status": "error", "message": str(e)})
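The refresh-once pattern above, distilled into a standalone sketch (a hypothetical helper, assuming a stale token surfaces as ServiceRequestError and that fetch stands in for any awaitable API call):

    from collections.abc import Awaitable, Callable
    from typing import Any

    from azure.core.exceptions import ServiceRequestError

    async def call_with_one_refresh(hook, fetch: Callable[[], Awaitable[Any]]) -> Any:
        # Allow exactly one refresh_conn() per call: retry once after a
        # refresh, and let a second consecutive failure propagate.
        refreshed = False
        while True:
            try:
                return await fetch()
            except ServiceRequestError:
                if refreshed:
                    raise
                await hook.refresh_conn()
                refreshed = True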

@@ -160,84 +166,93 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
            },
        )

    def _build_trigger_event(self, pipeline_status: str) -> TriggerEvent | None:
        """Build a TriggerEvent based on the pipeline status; return None if the status is not terminal."""
        if pipeline_status in AzureDataFactoryPipelineRunStatus.FAILURE_STATES:
            return TriggerEvent(
                {
                    "status": "error",
                    "message": f"The pipeline run {self.run_id} has {pipeline_status}.",
                    "run_id": self.run_id,
                }
            )
        if pipeline_status == AzureDataFactoryPipelineRunStatus.SUCCEEDED:
            return TriggerEvent(
                {
                    "status": "success",
                    "message": f"The pipeline run {self.run_id} has {pipeline_status}.",
                    "run_id": self.run_id,
                }
            )
        return None

    async def run(self) -> AsyncIterator[TriggerEvent]:
        """Make an async connection to Azure Data Factory and poll for the pipeline run status."""
        async with AzureDataFactoryAsyncHook(
            azure_data_factory_conn_id=self.azure_data_factory_conn_id
        ) as hook:
            try:
                pipeline_status = await hook.get_adf_pipeline_run_status(
                    run_id=self.run_id,
                    resource_group_name=self.resource_group_name,
                    factory_name=self.factory_name,
                )
                executed_after_token_refresh = True
                if self.wait_for_termination:
                    while self.end_time > time.time():
                        try:
                            pipeline_status = await hook.get_adf_pipeline_run_status(
                                run_id=self.run_id,
                                resource_group_name=self.resource_group_name,
                                factory_name=self.factory_name,
                            )
                            executed_after_token_refresh = True
                            event = self._build_trigger_event(pipeline_status)
                            if event:
                                yield event
                                return
                            self.log.info(
                                "Sleeping for %s. The pipeline state is %s.",
                                self.check_interval,
                                pipeline_status,
                            )
                            await asyncio.sleep(self.check_interval)
                        except ServiceRequestError:
                            # The connection might expire during a long-running pipeline.
                            # When this exception is caught, refresh the connection once.
                            # If that does not fix the issue, executed_after_token_refresh
                            # is still False on the next failure and the exception is raised.
                            if not executed_after_token_refresh:
                                raise
                            await hook.refresh_conn()
                            executed_after_token_refresh = False

                    yield TriggerEvent(
                        {
                            "status": "error",
                            "message": f"Timeout: The pipeline run {self.run_id} has {pipeline_status}.",
                            "run_id": self.run_id,
                        }
                    )
                else:
                    yield TriggerEvent(
                        {
                            "status": "success",
                            "message": f"The pipeline run {self.run_id} has {pipeline_status} status.",
                            "run_id": self.run_id,
                        }
                    )
            except Exception as e:
                self.log.exception(e)
                if self.run_id:
                    try:
                        self.log.info("Cancelling pipeline run %s", self.run_id)
                        await hook.cancel_pipeline_run(
                            run_id=self.run_id,
                            resource_group_name=self.resource_group_name,
                            factory_name=self.factory_name,
                        )
                    except Exception:
                        self.log.exception("Failed to cancel pipeline run %s", self.run_id)
                yield TriggerEvent({"status": "error", "message": str(e), "run_id": self.run_id})
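For orientation, a sketch of how these events surface to a consumer, driving the trigger's async iterator directly rather than through Airflow's triggerer (the constructor keywords mirror the serialized fields above; the connection id and resource names are placeholders):

    import asyncio
    import time

    from airflow.providers.microsoft.azure.triggers.data_factory import AzureDataFactoryTrigger

    async def wait_for_run(run_id: str) -> dict:
        # Iterate until the trigger yields its single terminal event and
        # return the payload dict: {"status": ..., "message": ..., "run_id": ...}.
        trigger = AzureDataFactoryTrigger(
            run_id=run_id,
            azure_data_factory_conn_id="azure_data_factory_default",
            resource_group_name="example-rg",
            factory_name="example-factory",
            wait_for_termination=True,
            end_time=time.time() + 60 * 60,  # one-hour deadline for the run
            check_interval=60,
        )
        async for event in trigger.run():
            return event.payload
        raise RuntimeError("Trigger finished without yielding an event.")

    # Example: payload = asyncio.run(wait_for_run("<adf-run-id>"))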
@@ -876,3 +876,38 @@ async def test_refresh_conn(self, mock_get_async_conn):
        await hook.refresh_conn()
        assert not hook._conn
        assert mock_get_async_conn.called

    @pytest.mark.asyncio
    async def test_close_method(self):
        """Test close method properly closes the async connection."""
        hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID)
        mock_conn = mock.AsyncMock()
        hook._async_conn = mock_conn

        await hook.close()

        mock_conn.close.assert_called_once()
        assert hook._async_conn is None

    @pytest.mark.asyncio
    async def test_close_method_when_conn_is_none(self):
        """Test close method does nothing when the connection is None."""
        hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID)
        hook._async_conn = None

        # Should not raise any exception
        await hook.close()
        assert hook._async_conn is None

    @pytest.mark.asyncio
    async def test_context_manager_calls_close(self):
        """Test async context manager calls close on exit."""
        hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID)
        mock_conn = mock.AsyncMock()
        hook._async_conn = mock_conn

        async with hook:
            pass

        mock_conn.close.assert_called_once()
        assert hook._async_conn is None
@@ -51,6 +51,30 @@ class TestADFPipelineRunStatusSensorTrigger:
        poke_interval=POKE_INTERVAL,
    )

    @pytest.mark.parametrize(
        ("pipeline_status", "expected_status", "expected_message"),
        [
            ("Failed", "error", f"Pipeline run {RUN_ID} has Failed."),
            ("Cancelled", "error", f"Pipeline run {RUN_ID} has been Cancelled."),
            ("Succeeded", "success", f"Pipeline run {RUN_ID} has been Succeeded."),
        ],
    )
    def test_build_trigger_event_terminal_states(self, pipeline_status, expected_status, expected_message):
        """Test _build_trigger_event returns the correct TriggerEvent for terminal states."""
        event = self.TRIGGER._build_trigger_event(pipeline_status)
        assert event is not None
        assert event.payload["status"] == expected_status
        assert event.payload["message"] == expected_message

    @pytest.mark.parametrize(
        "pipeline_status",
        ["Queued", "InProgress", "Canceling"],
    )
    def test_build_trigger_event_non_terminal_states(self, pipeline_status):
        """Test _build_trigger_event returns None for non-terminal states."""
        event = self.TRIGGER._build_trigger_event(pipeline_status)
        assert event is None

    def test_adf_pipeline_run_status_sensors_trigger_serialization(self):
        """
        Asserts that the ADFPipelineRunStatusSensorTrigger correctly serializes its arguments
@@ -186,6 +210,31 @@ class TestAzureDataFactoryTrigger:
        end_time=AZ_PIPELINE_END_TIME,
    )

    @pytest.mark.parametrize(
        ("pipeline_status", "expected_status"),
        [
            ("Failed", "error"),
            ("Cancelled", "error"),
            ("Succeeded", "success"),
        ],
    )
    def test_build_trigger_event_terminal_states(self, pipeline_status, expected_status):
        """Test _build_trigger_event returns the correct TriggerEvent for terminal states."""
        event = self.TRIGGER._build_trigger_event(pipeline_status)
        assert event is not None
        assert event.payload["status"] == expected_status
        assert event.payload["run_id"] == AZ_PIPELINE_RUN_ID
        assert f"The pipeline run {AZ_PIPELINE_RUN_ID} has {pipeline_status}." in event.payload["message"]

    @pytest.mark.parametrize(
        "pipeline_status",
        ["Queued", "InProgress", "Canceling"],
    )
    def test_build_trigger_event_non_terminal_states(self, pipeline_status):
        """Test _build_trigger_event returns None for non-terminal states."""
        event = self.TRIGGER._build_trigger_event(pipeline_status)
        assert event is None

    def test_azure_data_factory_trigger_serialization(self):
        """Asserts that the AzureDataFactoryTrigger correctly serializes its arguments and classpath."""
