@@ -276,6 +276,14 @@ def _init_pipeline_options(
is_dataflow_job_id_exist_callback,
)

@property
def extra_links_params(self) -> dict[str, Any]:
return {
"project_id": self.dataflow_config.project_id,
"region": self.dataflow_config.location,
"job_id": self.dataflow_job_id,
}

def execute_complete(self, context: Context, event: dict[str, Any]):
"""
Execute when the trigger fires - returns immediately.
@@ -443,13 +451,7 @@ def execute_on_dataflow(self, context: Context):
)

location = self.dataflow_config.location or DEFAULT_DATAFLOW_LOCATION
DataflowJobLink.persist(
self,
context,
self.dataflow_config.project_id,
location,
self.dataflow_job_id,
)
DataflowJobLink.persist(context=context, region=location)

if self.deferrable:
trigger_args = {
@@ -626,13 +628,7 @@ def execute_on_dataflow(self, context: Context):
is_dataflow_job_id_exist_callback=self.is_dataflow_job_id_exist_callback,
)
if self.dataflow_job_name and self.dataflow_config.location:
DataflowJobLink.persist(
self,
context,
self.dataflow_config.project_id,
self.dataflow_config.location,
self.dataflow_job_id,
)
DataflowJobLink.persist(context=context)
if self.deferrable:
trigger_args = {
"job_id": self.dataflow_job_id,
@@ -795,14 +791,7 @@ def execute(self, context: Context):
variables=snake_case_pipeline_options,
process_line_callback=process_line_callback,
)

DataflowJobLink.persist(
self,
context,
self.dataflow_config.project_id,
self.dataflow_config.location,
self.dataflow_job_id,
)
DataflowJobLink.persist(context=context)
if dataflow_job_name and self.dataflow_config.location:
self.dataflow_hook.wait_for_done(
job_name=dataflow_job_name,
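Taken together, the hunks above replace the positional DataflowJobLink.persist(...) calls with the new keyword-only form, while the shared values move into the operator's extra_links_params property. A minimal sketch of the resulting pattern, using a hypothetical MyDataflowOperator (not part of this change) and assuming the import paths shown:

from __future__ import annotations

from typing import Any

from airflow.providers.google.cloud.links.dataflow import DataflowJobLink
from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator


class MyDataflowOperator(GoogleCloudBaseOperator):
    """Hypothetical operator illustrating the extra_links_params pattern."""

    operator_extra_links = (DataflowJobLink(),)

    def __init__(self, *, project_id: str, location: str | None = None, **kwargs) -> None:
        super().__init__(**kwargs)
        self.project_id = project_id
        self.location = location
        self.dataflow_job_id: str | None = None

    @property
    def extra_links_params(self) -> dict[str, Any]:
        # Shared link parameters; the link's get_link method merges these with
        # whatever persist() pushed to XCom.
        return {
            "project_id": self.project_id,
            "region": self.location,
            "job_id": self.dataflow_job_id,
        }

    def execute(self, context) -> None:
        self.dataflow_job_id = "example-job-id"  # placeholder for a real job id
        # Only values that differ from extra_links_params need to be passed explicitly.
        DataflowJobLink.persist(context=context, region=self.location or "us-central1")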
83 changes: 48 additions & 35 deletions providers/apache/beam/tests/unit/apache/beam/operators/test_beam.py
@@ -23,7 +23,7 @@

import pytest

from airflow.exceptions import AirflowException, TaskDeferred
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, TaskDeferred
from airflow.providers.apache.beam.operators.beam import (
BeamBasePipelineOperator,
BeamRunGoPipelineOperator,
@@ -32,6 +32,7 @@
)
from airflow.providers.apache.beam.triggers.beam import BeamJavaPipelineTrigger, BeamPythonPipelineTrigger
from airflow.providers.google.cloud.operators.dataflow import DataflowConfiguration
from airflow.providers.google.version_compat import AIRFLOW_V_3_0_PLUS
from airflow.version import version

TASK_ID = "test-beam-operator"
@@ -233,13 +234,7 @@ def test_exec_dataflow_runner(
}
gcs_provide_file.assert_any_call(object_url=PY_FILE)
gcs_provide_file.assert_any_call(object_url=REQURIEMENTS_FILE)
persist_link_mock.assert_called_once_with(
op,
{},
expected_options["project"],
expected_options["region"],
op.dataflow_job_id,
)
persist_link_mock.assert_called_once_with(context={}, region="us-central1")
beam_hook_mock.return_value.start_python_pipeline.assert_called_once_with(
variables=expected_options,
py_file=gcs_provide_file.return_value.__enter__.return_value.name,
@@ -446,13 +441,7 @@ def test_exec_dataflow_runner(self, gcs_hook, dataflow_hook_mock, beam_hook_mock
"output": "gs://test/output",
"impersonateServiceAccount": TEST_IMPERSONATION_ACCOUNT,
}
persist_link_mock.assert_called_once_with(
op,
{},
expected_options["project"],
expected_options["region"],
op.dataflow_job_id,
)
persist_link_mock.assert_called_once_with(context={})
beam_hook_mock.return_value.start_java_pipeline.assert_called_once_with(
variables=expected_options,
jar=gcs_provide_file.return_value.__enter__.return_value.name,
@@ -753,13 +742,7 @@ def test_exec_dataflow_runner_with_go_file(
"labels": {"foo": "bar", "airflow-version": TEST_VERSION},
"region": "us-central1",
}
persist_link_mock.assert_called_once_with(
op,
{},
expected_options["project"],
expected_options["region"],
op.dataflow_job_id,
)
persist_link_mock.assert_called_once_with(context={})
expected_go_file = "/tmp/apache-beam-go/main.go"
gcs_download_method.assert_called_once_with(
bucket_name="my-bucket", object_name="example/main.go", filename=expected_go_file
@@ -859,13 +842,7 @@ def gcs_download_side_effect(bucket_name: str, object_name: str, filename: str)
worker_binary=expected_worker_binary,
process_line_callback=mock.ANY,
)
mock_persist_link.assert_called_once_with(
operator,
{},
dataflow_config.project_id,
dataflow_config.location,
operator.dataflow_job_id,
)
mock_persist_link.assert_called_once_with(context={})
wait_for_done_method.assert_called_once_with(
job_name=expected_job_name,
location=dataflow_config.location,
@@ -970,8 +947,20 @@ def test_exec_dataflow_runner(self, gcs_hook_mock, dataflow_hook_mock, beam_hook
**self.default_op_kwargs,
)
magic_mock = mock.MagicMock()
with pytest.raises(TaskDeferred):
op.execute(context=magic_mock)
if AIRFLOW_V_3_0_PLUS:
with pytest.raises(TaskDeferred):
op.execute(context=magic_mock)
else:
exception_msg = (
"GoogleBaseLink.persist method call with no extra value is Deprecated for Airflow 3."
" The method calls (only with context) needs to be removed after the Airflow 3 Migration"
" completed!"
)
with (
pytest.raises(TaskDeferred),
pytest.warns(AirflowProviderDeprecationWarning, match=exception_msg),
):
op.execute(context=magic_mock)

dataflow_hook_mock.assert_called_once_with(
gcp_conn_id=dataflow_config.gcp_conn_id,
@@ -1005,8 +994,20 @@ def test_on_kill_dataflow_runner(self, dataflow_hook_mock, _, __, ___):
def test_on_kill_direct_runner(self, _, dataflow_mock, __):
dataflow_cancel_job = dataflow_mock.return_value.cancel_job
op = BeamRunPythonPipelineOperator(runner="DataflowRunner", **self.default_op_kwargs)
with pytest.raises(TaskDeferred):
op.execute(mock.MagicMock())
if AIRFLOW_V_3_0_PLUS:
with pytest.raises(TaskDeferred):
op.execute(mock.MagicMock())
else:
exception_msg = (
"GoogleBaseLink.persist method call with no extra value is Deprecated for Airflow 3."
" The method calls (only with context) needs to be removed after the Airflow 3 Migration"
" completed!"
)
with (
pytest.raises(TaskDeferred),
pytest.warns(AirflowProviderDeprecationWarning, match=exception_msg),
):
op.execute(mock.MagicMock())
op.on_kill()
dataflow_cancel_job.assert_not_called()

@@ -1075,8 +1076,20 @@ def test_exec_dataflow_runner(self, gcs_hook_mock, dataflow_hook_mock, beam_hook
)
dataflow_hook_mock.return_value.is_job_dataflow_running.return_value = False
magic_mock = mock.MagicMock()
with pytest.raises(TaskDeferred):
op.execute(context=magic_mock)
if AIRFLOW_V_3_0_PLUS:
with pytest.raises(TaskDeferred):
op.execute(context=magic_mock)
else:
exception_msg = (
"GoogleBaseLink.persist method call with no extra value is Deprecated for Airflow 3."
" The method calls (only with context) needs to be removed after the Airflow 3 Migration"
" completed!"
)
with (
pytest.raises(TaskDeferred),
pytest.warns(AirflowProviderDeprecationWarning, match=exception_msg),
):
op.execute(context=magic_mock)

dataflow_hook_mock.assert_called_once_with(
gcp_conn_id=dataflow_config.gcp_conn_id,
@@ -19,14 +19,8 @@

from __future__ import annotations

from typing import TYPE_CHECKING

from airflow.providers.google.cloud.links.base import BaseGoogleLink

if TYPE_CHECKING:
from airflow.models import BaseOperator
from airflow.utils.context import Context

ALLOY_DB_BASE_LINK = "/alloydb"
ALLOY_DB_CLUSTER_LINK = (
ALLOY_DB_BASE_LINK + "/locations/{location_id}/clusters/{cluster_id}?project={project_id}"
@@ -44,20 +38,6 @@ class AlloyDBClusterLink(BaseGoogleLink):
key = "alloy_db_cluster"
format_str = ALLOY_DB_CLUSTER_LINK

@staticmethod
def persist(
context: Context,
task_instance: BaseOperator,
location_id: str,
cluster_id: str,
project_id: str | None,
):
task_instance.xcom_push(
context,
key=AlloyDBClusterLink.key,
value={"location_id": location_id, "cluster_id": cluster_id, "project_id": project_id},
)


class AlloyDBUsersLink(BaseGoogleLink):
"""Helper class for constructing AlloyDB users Link."""
@@ -66,36 +46,10 @@ class AlloyDBUsersLink(BaseGoogleLink):
key = "alloy_db_users"
format_str = ALLOY_DB_USERS_LINK

@staticmethod
def persist(
context: Context,
task_instance: BaseOperator,
location_id: str,
cluster_id: str,
project_id: str | None,
):
task_instance.xcom_push(
context,
key=AlloyDBUsersLink.key,
value={"location_id": location_id, "cluster_id": cluster_id, "project_id": project_id},
)


class AlloyDBBackupsLink(BaseGoogleLink):
"""Helper class for constructing AlloyDB backups Link."""

name = "AlloyDB Backups"
key = "alloy_db_backups"
format_str = ALLOY_DB_BACKUPS_LINK

@staticmethod
def persist(
context: Context,
task_instance: BaseOperator,
project_id: str | None,
):
task_instance.xcom_push(
context,
key=AlloyDBBackupsLink.key,
value={"project_id": project_id},
)
@@ -24,6 +24,9 @@
if TYPE_CHECKING:
from airflow.models import BaseOperator
from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
from airflow.sdk import BaseSensorOperator
from airflow.utils.context import Context

if AIRFLOW_V_3_0_PLUS:
from airflow.sdk import BaseOperatorLink
@@ -39,22 +42,82 @@ class BaseGoogleLink(BaseOperatorLink):
"""
Base class for all Google links.

When you inherit from this class in a link class:
- You can call the persist method to push data to XCom and use it later in the get_link method.
- If your operator inherits from GoogleCloudBaseOperator or BaseSensorOperator,
you can define an extra_links_params property on the operator to pass operator properties
to the get_link method.

:meta private:
"""

name: ClassVar[str]
key: ClassVar[str]
format_str: ClassVar[str]

@property
def xcom_key(self) -> str:
# NOTE: in Airflow 3 the Link class needs an xcom_key property.
# Since we already have the key property, this is just a proxy property so that the same
# key is used as in Airflow 2.
return self.key

@classmethod
def persist(cls, context: Context, **value):
"""
Push arguments to XCom for later use when formatting the link in the `get_link` method.

Note: on Airflow 2, this method is called with only the context argument in operators
that define the extra_links_params property.
"""
params = {}
# TODO: remove after Airflow v2 support dropped
if not AIRFLOW_V_3_0_PLUS:
common_params = getattr(context["task"], "extra_links_params", None)
if common_params:
params.update(common_params)

context["ti"].xcom_push(
key=cls.key,
value={
**params,
**value,
},
)

def get_config(self, operator, ti_key):
conf = {}
conf.update(getattr(operator, "extra_links_params", {}))
conf.update(XCom.get_value(key=self.key, ti_key=ti_key) or {})

# if no config values were found, return None to stop URL formatting
if not conf:
return None

# Add a default value for the 'namespace' parameter for backward compatibility.
# This is needed for the Data Fusion links.
conf.setdefault("namespace", "default")
return conf

def get_link(
self,
operator: BaseOperator,
*,
ti_key: TaskInstanceKey,
) -> str:
conf = XCom.get_value(key=self.key, ti_key=ti_key)
if TYPE_CHECKING:
assert isinstance(operator, (GoogleCloudBaseOperator, BaseSensorOperator))

conf = self.get_config(operator, ti_key)
if not conf:
return ""
if self.format_str.startswith("http"):
return self.format_str.format(**conf)
return BASE_LINK + self.format_str.format(**conf)
return self._format_link(**conf)

def _format_link(self, **kwargs):
try:
formatted_str = self.format_str.format(**kwargs)
if formatted_str.startswith("http"):
return formatted_str
return BASE_LINK + formatted_str
except KeyError:
return ""