Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -907,7 +907,7 @@ def execute_complete(self, context: Context, event: dict[str, Any]) -> Any:
cluster_state = event["cluster_state"]
cluster_name = event["cluster_name"]

if cluster_state == ClusterStatus.State.ERROR:
if cluster_state == ClusterStatus.State(ClusterStatus.State.DELETING).name:
raise AirflowException(f"Cluster is in ERROR state:\n{cluster_name}")

self.log.info("%s completed successfully.", self.task_id)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -316,23 +316,24 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
yield TriggerEvent(
{
"cluster_name": self.cluster_name,
"cluster_state": ClusterStatus.State.DELETING,
"cluster": cluster,
"cluster_state": ClusterStatus.State(ClusterStatus.State.DELETING).name,
"cluster": Cluster.to_dict(cluster),
}
)
return
elif state == ClusterStatus.State.RUNNING:
yield TriggerEvent(
{
"cluster_name": self.cluster_name,
"cluster_state": state,
"cluster": cluster,
"cluster_state": ClusterStatus.State(state).name,
"cluster": Cluster.to_dict(cluster),
}
)
return
self.log.info("Current state is %s", state)
self.log.info("Sleeping for %s seconds.", self.polling_interval_seconds)
await asyncio.sleep(self.polling_interval_seconds)
else:
self.log.info("Current state is %s", state)
self.log.info("Sleeping for %s seconds.", self.polling_interval_seconds)
await asyncio.sleep(self.polling_interval_seconds)
except asyncio.CancelledError:
try:
if self.delete_on_error and await self.safe_to_cancel():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import asyncio
import contextlib
import logging
from asyncio import CancelledError, Future, sleep
from unittest import mock

Expand Down Expand Up @@ -50,6 +51,14 @@
TEST_GCP_CONN_ID = "google_cloud_default"
TEST_OPERATION_NAME = "name"
TEST_JOB_ID = "test-job-id"
# Reusable Cluster fixtures for the cluster-trigger tests below:
# one cluster in RUNNING state (success path) and one in ERROR state
# (failure/delete-on-error path). Both share TEST_CLUSTER_NAME.
TEST_RUNNING_CLUSTER = Cluster(
cluster_name=TEST_CLUSTER_NAME,
status=ClusterStatus(state=ClusterStatus.State.RUNNING),
)
TEST_ERROR_CLUSTER = Cluster(
cluster_name=TEST_CLUSTER_NAME,
status=ClusterStatus(state=ClusterStatus.State.ERROR),
)


@pytest.fixture
Expand Down Expand Up @@ -158,28 +167,56 @@ def test_async_cluster_trigger_serialization_should_execute_successfully(self, c
@pytest.mark.db_test
@pytest.mark.asyncio
@mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger.get_async_hook")
async def test_async_cluster_triggers_on_success_should_execute_successfully(
    self, mock_get_async_hook, cluster_trigger
):
    """Trigger yields a success TriggerEvent when the cluster reaches RUNNING.

    The async hook's ``get_cluster`` is mocked to resolve to a RUNNING
    cluster; the trigger must emit exactly one event whose
    ``cluster_state`` is the enum *name* string (not the enum member),
    matching what the trigger serializes for cross-process transport.
    """
    # get_cluster is awaited by the trigger, so the mock must return an
    # already-resolved Future rather than a plain value.
    future = asyncio.Future()
    future.set_result(TEST_RUNNING_CLUSTER)
    mock_get_async_hook.return_value.get_cluster.return_value = future

    generator = cluster_trigger.run()
    # Advance the async generator one step to obtain the first (and only) event.
    actual_event = await generator.asend(None)

    expected_event = TriggerEvent(
        {
            "cluster_name": TEST_CLUSTER_NAME,
            # The trigger emits the state's string name for JSON-safe payloads.
            "cluster_state": ClusterStatus.State(ClusterStatus.State.RUNNING).name,
            # The serialized cluster dict is taken from the actual event since
            # its exact contents depend on Cluster.to_dict().
            "cluster": actual_event.payload["cluster"],
        }
    )
    assert expected_event == actual_event

@pytest.mark.db_test
@pytest.mark.asyncio
@mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger.fetch_cluster")
@mock.patch(
    "airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.delete_cluster",
    return_value=asyncio.Future(),
)
@mock.patch("google.auth.default")
async def test_async_cluster_trigger_run_returns_error_event(
    self, mock_auth, mock_delete_cluster, mock_fetch_cluster, cluster_trigger, async_get_cluster, caplog
):
    """Trigger reports DELETING (by name) when the cluster lands in ERROR.

    ``fetch_cluster`` is mocked to return an ERROR-state cluster; the
    trigger is then expected to delete the broken cluster and yield an
    event whose ``cluster_state`` is the DELETING enum's *name* string.
    Note: mock parameters are injected bottom-up from the decorators:
    google.auth.default -> mock_auth, delete_cluster -> mock_delete_cluster,
    fetch_cluster -> mock_fetch_cluster.
    """
    # Stub out Google credential discovery so no real ADC lookup happens.
    mock_credentials = mock.MagicMock()
    mock_credentials.universe_domain = "googleapis.com"
    mock_auth.return_value = (mock_credentials, "project-id")

    # delete_cluster is awaited; hand it a pre-resolved Future.
    mock_delete_cluster.return_value = asyncio.Future()
    mock_delete_cluster.return_value.set_result(None)

    mock_fetch_cluster.return_value = TEST_ERROR_CLUSTER

    caplog.set_level(logging.INFO)

    # Drain the trigger's event stream; keep the last (only) event.
    trigger_event = None
    async for event in cluster_trigger.run():
        trigger_event = event

    assert trigger_event.payload["cluster_name"] == TEST_CLUSTER_NAME
    assert (
        trigger_event.payload["cluster_state"] == ClusterStatus.State(ClusterStatus.State.DELETING).name
    )

@pytest.mark.db_test
@pytest.mark.asyncio
Expand Down Expand Up @@ -321,31 +358,6 @@ async def test_cluster_trigger_run_cancelled_not_safe_to_cancel(
assert mock_delete_cluster.call_count == 0
mock_delete_cluster.assert_not_called()

# NOTE(review): stale duplicate — a test with this exact name is already
# defined earlier in this class. Because Python binds the LAST definition,
# this older copy silently overrides the updated one, so the updated test
# never runs. It also expects the raw enum member
# (ClusterStatus.State.RUNNING) in "cluster_state", while the trigger now
# emits the enum's .name string — presumably this copy should be deleted;
# confirm against the merged file.
@pytest.mark.db_test
@pytest.mark.asyncio
@mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger.get_async_hook")
async def test_async_cluster_triggers_on_success_should_execute_successfully(
self, mock_get_async_hook, cluster_trigger
):
# Old-style setup: a MagicMock stands in for the cluster object.
mock_cluster = mock.MagicMock()
mock_cluster.status = ClusterStatus(state=ClusterStatus.State.RUNNING)

# get_cluster is awaited by the trigger, so wrap the result in a Future.
future = asyncio.Future()
future.set_result(mock_cluster)
mock_get_async_hook.return_value.get_cluster.return_value = future

generator = cluster_trigger.run()
actual_event = await generator.asend(None)

expected_event = TriggerEvent(
{
"cluster_name": TEST_CLUSTER_NAME,
"cluster_state": ClusterStatus.State.RUNNING,
"cluster": actual_event.payload["cluster"],
}
)
assert expected_event == actual_event


class TestDataprocBatchTrigger:
def test_async_create_batch_trigger_serialization_should_execute_successfully(self, batch_trigger):
Expand Down