Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,11 @@
import json
from collections.abc import Iterable, Sequence
from datetime import datetime, timedelta
from functools import cached_property
from typing import TYPE_CHECKING

from dateutil import parser
from google.cloud.orchestration.airflow.service_v1.types import ExecuteAirflowCommandResponse
from google.cloud.orchestration.airflow.service_v1.types import Environment, ExecuteAirflowCommandResponse

from airflow.configuration import conf
from airflow.exceptions import AirflowException
Expand Down Expand Up @@ -135,19 +136,20 @@ def poke(self, context: Context) -> bool:

def _pull_dag_runs(self) -> list[dict]:
"""Pull the list of dag runs."""
hook = CloudComposerHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
cmd_parameters = (
["-d", self.composer_dag_id, "-o", "json"]
if self._composer_airflow_version < 3
else [self.composer_dag_id, "-o", "json"]
)
dag_runs_cmd = hook.execute_airflow_command(
dag_runs_cmd = self.hook.execute_airflow_command(
project_id=self.project_id,
region=self.region,
environment_id=self.environment_id,
command="dags",
subcommand="list-runs",
parameters=["-d", self.composer_dag_id, "-o", "json"],
parameters=cmd_parameters,
)
cmd_result = hook.wait_command_execution_result(
cmd_result = self.hook.wait_command_execution_result(
project_id=self.project_id,
region=self.region,
environment_id=self.environment_id,
Expand All @@ -165,13 +167,27 @@ def _check_dag_runs_states(
for dag_run in dag_runs:
if (
start_date.timestamp()
< parser.parse(dag_run["logical_date"]).timestamp()
< parser.parse(
dag_run["execution_date" if self._composer_airflow_version < 3 else "logical_date"]
).timestamp()
< end_date.timestamp()
) and dag_run["state"] not in self.allowed_states:
return False
return True

def _get_composer_airflow_version(self) -> int:
    """
    Return the major Airflow version of the Composer environment.

    Fetches the environment through ``self.hook``, converts it to a dict and
    reads ``config.software_config.image_version``. The integer returned is the
    first dot-separated component of the text that follows ``"airflow-"`` in
    that image version string.

    :return: Major Airflow version as an integer.
    """
    environment = self.hook.get_environment(
        project_id=self.project_id,
        region=self.region,
        environment_id=self.environment_id,
    )
    software_config = Environment.to_dict(environment)["config"]["software_config"]
    # NOTE(review): presumably image_version looks like "composer-X-airflow-Y.Z..."; confirm
    # against the Composer API. Only the text after the "airflow-" marker is used here.
    airflow_version = software_config["image_version"].split("airflow-")[1]
    major, *_ = airflow_version.split(".")
    return int(major)

def execute(self, context: Context) -> None:
self._composer_airflow_version = self._get_composer_airflow_version()
if self.deferrable:
start_date, end_date = self._get_logical_dates(context)
self.defer(
Expand All @@ -186,6 +202,7 @@ def execute(self, context: Context) -> None:
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
poll_interval=self.poll_interval,
composer_airflow_version=self._composer_airflow_version,
),
method_name=GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME,
)
Expand All @@ -195,3 +212,10 @@ def execute_complete(self, context: Context, event: dict):
if event and event["status"] == "error":
raise AirflowException(event["message"])
self.log.info("DAG %s has executed successfully.", self.composer_dag_id)

@cached_property
def hook(self) -> CloudComposerHook:
    """
    Return a ``CloudComposerHook`` built from this object's connection settings.

    The hook is created on first access using ``gcp_conn_id`` and
    ``impersonation_chain``; ``cached_property`` stores it so later accesses
    reuse the same instance.
    """
    connection_kwargs = {
        "gcp_conn_id": self.gcp_conn_id,
        "impersonation_chain": self.impersonation_chain,
    }
    return CloudComposerHook(**connection_kwargs)
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ def __init__(
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
poll_interval: int = 10,
composer_airflow_version: int = 2,
):
super().__init__()
self.project_id = project_id
Expand All @@ -181,6 +182,7 @@ def __init__(
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain
self.poll_interval = poll_interval
self.composer_airflow_version = composer_airflow_version

self.gcp_hook = CloudComposerAsyncHook(
gcp_conn_id=self.gcp_conn_id,
Expand All @@ -201,18 +203,24 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"gcp_conn_id": self.gcp_conn_id,
"impersonation_chain": self.impersonation_chain,
"poll_interval": self.poll_interval,
"composer_airflow_version": self.composer_airflow_version,
},
)

async def _pull_dag_runs(self) -> list[dict]:
"""Pull the list of dag runs."""
cmd_parameters = (
["-d", self.composer_dag_id, "-o", "json"]
if self.composer_airflow_version < 3
else [self.composer_dag_id, "-o", "json"]
)
dag_runs_cmd = await self.gcp_hook.execute_airflow_command(
project_id=self.project_id,
region=self.region,
environment_id=self.environment_id,
command="dags",
subcommand="list-runs",
parameters=["-d", self.composer_dag_id, "-o", "json"],
parameters=cmd_parameters,
)
cmd_result = await self.gcp_hook.wait_command_execution_result(
project_id=self.project_id,
Expand All @@ -232,7 +240,9 @@ def _check_dag_runs_states(
for dag_run in dag_runs:
if (
start_date.timestamp()
< parser.parse(dag_run["logical_date"]).timestamp()
< parser.parse(
dag_run["execution_date" if self.composer_airflow_version < 3 else "logical_date"]
).timestamp()
< end_date.timestamp()
) and dag_run["state"] not in self.allowed_states:
return False
Expand Down
26 changes: 18 additions & 8 deletions providers/tests/google/cloud/sensors/test_cloud_composer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,36 +21,41 @@
from datetime import datetime
from unittest import mock

import pytest

from airflow.providers.google.cloud.sensors.cloud_composer import CloudComposerDAGRunSensor

TEST_PROJECT_ID = "test_project_id"
TEST_OPERATION_NAME = "test_operation_name"
TEST_REGION = "region"
TEST_ENVIRONMENT_ID = "test_env_id"
TEST_JSON_RESULT = lambda state: json.dumps(
TEST_JSON_RESULT = lambda state, date_key: json.dumps(
[
{
"dag_id": "test_dag_id",
"run_id": "scheduled__2024-05-22T11:10:00+00:00",
"state": state,
"logical_date": "2024-05-22T11:10:00+00:00",
date_key: "2024-05-22T11:10:00+00:00",
"start_date": "2024-05-22T11:20:01.531988+00:00",
"end_date": "2024-05-22T11:20:11.997479+00:00",
}
]
)
TEST_EXEC_RESULT = lambda state: {
"output": [{"line_number": 1, "content": TEST_JSON_RESULT(state)}],
TEST_EXEC_RESULT = lambda state, date_key: {
"output": [{"line_number": 1, "content": TEST_JSON_RESULT(state, date_key)}],
"output_end": True,
"exit_info": {"exit_code": 0, "error": ""},
}


class TestCloudComposerDAGRunSensor:
@pytest.mark.parametrize("composer_airflow_version", [2, 3])
@mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.ExecuteAirflowCommandResponse.to_dict")
@mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.CloudComposerHook")
def test_wait_ready(self, mock_hook, to_dict_mode):
mock_hook.return_value.wait_command_execution_result.return_value = TEST_EXEC_RESULT("success")
def test_wait_ready(self, mock_hook, to_dict_mode, composer_airflow_version):
mock_hook.return_value.wait_command_execution_result.return_value = TEST_EXEC_RESULT(
"success", "execution_date" if composer_airflow_version < 3 else "logical_date"
)

task = CloudComposerDAGRunSensor(
task_id="task-id",
Expand All @@ -60,13 +65,17 @@ def test_wait_ready(self, mock_hook, to_dict_mode):
composer_dag_id="test_dag_id",
allowed_states=["success"],
)
task._composer_airflow_version = composer_airflow_version

assert task.poke(context={"logical_date": datetime(2024, 5, 23, 0, 0, 0)})

@pytest.mark.parametrize("composer_airflow_version", [2, 3])
@mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.ExecuteAirflowCommandResponse.to_dict")
@mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.CloudComposerHook")
def test_wait_not_ready(self, mock_hook, to_dict_mode):
mock_hook.return_value.wait_command_execution_result.return_value = TEST_EXEC_RESULT("running")
def test_wait_not_ready(self, mock_hook, to_dict_mode, composer_airflow_version):
mock_hook.return_value.wait_command_execution_result.return_value = TEST_EXEC_RESULT(
"running", "execution_date" if composer_airflow_version < 3 else "logical_date"
)

task = CloudComposerDAGRunSensor(
task_id="task-id",
Expand All @@ -76,5 +85,6 @@ def test_wait_not_ready(self, mock_hook, to_dict_mode):
composer_dag_id="test_dag_id",
allowed_states=["success"],
)
task._composer_airflow_version = composer_airflow_version

assert not task.poke(context={"logical_date": datetime(2024, 5, 23, 0, 0, 0)})
3 changes: 3 additions & 0 deletions providers/tests/google/cloud/triggers/test_cloud_composer.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
TEST_STATES = ["success"]
TEST_GCP_CONN_ID = "test_gcp_conn_id"
TEST_POLL_INTERVAL = 10
TEST_COMPOSER_AIRFLOW_VERSION = 3
TEST_IMPERSONATION_CHAIN = "test_impersonation_chain"
TEST_EXEC_RESULT = {
"output": [{"line_number": 1, "content": "test_content"}],
Expand Down Expand Up @@ -86,6 +87,7 @@ def dag_run_trigger(mock_conn):
gcp_conn_id=TEST_GCP_CONN_ID,
impersonation_chain=TEST_IMPERSONATION_CHAIN,
poll_interval=TEST_POLL_INTERVAL,
composer_airflow_version=TEST_COMPOSER_AIRFLOW_VERSION,
)


Expand Down Expand Up @@ -140,6 +142,7 @@ def test_serialize(self, dag_run_trigger):
"gcp_conn_id": TEST_GCP_CONN_ID,
"impersonation_chain": TEST_IMPERSONATION_CHAIN,
"poll_interval": TEST_POLL_INTERVAL,
"composer_airflow_version": TEST_COMPOSER_AIRFLOW_VERSION,
},
)
assert actual_data == expected_data