Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
DataflowJobStatusTrigger,
)
from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID
from airflow.providers.google.version_compat import BaseSensorOperator
from airflow.providers.google.version_compat import BaseSensorOperator, PokeReturnValue

if TYPE_CHECKING:
from airflow.utils.context import Context
Expand Down Expand Up @@ -342,7 +342,7 @@ def __init__(
self.deferrable = deferrable
self.poll_interval = poll_interval

def poke(self, context: Context) -> bool:
def poke(self, context: Context) -> PokeReturnValue | bool:
if self.fail_on_terminal_state:
job = self.hook.get_job(
job_id=self.job_id,
Expand All @@ -359,8 +359,17 @@ def poke(self, context: Context) -> bool:
project_id=self.project_id,
location=self.location,
)
result = result if self.callback is None else self.callback(result)

if isinstance(result, PokeReturnValue):
return result

return result if self.callback is None else self.callback(result)
if bool(result):
return PokeReturnValue(
is_done=True,
xcom_value=result,
)
return False

def execute(self, context: Context) -> Any:
"""Airflow runs this method on the worker and defers using the trigger."""
Expand Down Expand Up @@ -464,7 +473,7 @@ def __init__(
self.deferrable = deferrable
self.poll_interval = poll_interval

def poke(self, context: Context) -> bool:
def poke(self, context: Context) -> PokeReturnValue | bool:
if self.fail_on_terminal_state:
job = self.hook.get_job(
job_id=self.job_id,
Expand All @@ -481,8 +490,16 @@ def poke(self, context: Context) -> bool:
project_id=self.project_id,
location=self.location,
)

return result if self.callback is None else self.callback(result)
result = result if self.callback is None else self.callback(result)
if isinstance(result, PokeReturnValue):
return result

if bool(result):
return PokeReturnValue(
is_done=True,
xcom_value=result,
)
return False

def execute(self, context: Context) -> Any:
"""Airflow runs this method on the worker and defers using the trigger."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,11 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]:
from airflow.sdk import (
BaseOperatorLink,
BaseSensorOperator,
PokeReturnValue,
)
else:
from airflow.models import BaseOperatorLink # type: ignore[no-redef]
from airflow.sensors.base import BaseSensorOperator # type: ignore[no-redef]
from airflow.sensors.base import BaseSensorOperator, PokeReturnValue # type: ignore[no-redef]

# Explicitly export these imports to protect them from being removed by linters
__all__ = [
Expand All @@ -65,4 +66,5 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]:
"BaseOperator",
"BaseSensorOperator",
"BaseOperatorLink",
"PokeReturnValue",
]
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ def test_poke(self, mock_hook, job_current_state, fail_on_terminal_state):

results = task.poke(mock.MagicMock())

assert callback.return_value == results
assert callback.return_value == results.xcom_value

mock_hook.assert_called_once_with(
gcp_conn_id=TEST_GCP_CONN_ID,
Expand Down Expand Up @@ -552,7 +552,7 @@ def test_poke(self, mock_hook, job_current_state, fail_on_terminal_state):

results = task.poke(mock.MagicMock())

assert callback.return_value == results
assert callback.return_value == results.xcom_value

mock_hook.assert_called_once_with(
gcp_conn_id=TEST_GCP_CONN_ID,
Expand Down