Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions providers/http/src/airflow/providers/http/sensors/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
if TYPE_CHECKING:
try:
from airflow.sdk.definitions.context import Context
from airflow.sensors.base import PokeReturnValue
except ImportError:
# TODO: Remove once provider drops support for Airflow 2
from airflow.utils.context import Context
Expand Down Expand Up @@ -102,7 +103,7 @@ def __init__(
request_kwargs: dict[str, Any] | None = None,
headers: dict[str, Any] | None = None,
response_error_codes_allowlist: list[str] | None = None,
response_check: Callable[..., bool] | None = None,
response_check: Callable[..., bool | PokeReturnValue] | None = None,
extra_options: dict[str, Any] | None = None,
tcp_keep_alive: bool = True,
tcp_keep_alive_idle: int = 120,
Expand All @@ -129,7 +130,7 @@ def __init__(
self.deferrable = deferrable
self.request_kwargs = request_kwargs or {}

def poke(self, context: Context) -> bool:
def poke(self, context: Context) -> bool | PokeReturnValue:
from airflow.utils.operator_helpers import determine_kwargs

hook = HttpHook(
Expand Down Expand Up @@ -163,9 +164,9 @@ def poke(self, context: Context) -> bool:

return True

def execute(self, context: Context) -> None:
def execute(self, context: Context) -> Any:
if not self.deferrable or self.response_check:
super().execute(context=context)
return super().execute(context=context)
elif not self.poke(context):
self.defer(
timeout=timedelta(seconds=self.timeout),
Expand Down
28 changes: 28 additions & 0 deletions providers/http/tests/unit/http/sensors/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from airflow.providers.http.operators.http import HttpOperator
from airflow.providers.http.sensors.http import HttpSensor
from airflow.providers.http.triggers.http import HttpSensorTrigger
from airflow.sensors.base import PokeReturnValue
from airflow.utils.timezone import datetime

pytestmark = pytest.mark.db_test
Expand Down Expand Up @@ -65,6 +66,33 @@ def resp_check(_):
with pytest.raises(AirflowException, match="AirflowException raised here!"):
task.execute(context={})

@patch("airflow.providers.http.hooks.http.Session.send")
def test_poke_xcom_value(self, mock_session_send, create_task_of_operator):
"""
XCom value can be generated via response_check
"""
response = requests.Response()
response.status_code = 200
response._content = b'{"data": "somedata"}'
response.headers["Content-Type"] = "application/json"
mock_session_send.return_value = response

def resp_check(rsp):
return PokeReturnValue(is_done=True, xcom_value=rsp.json()["data"])

task = create_task_of_operator(
HttpSensor,
dag_id="http_sensor_poke_exception",
task_id="http_sensor_poke_exception",
http_conn_id="http_default",
endpoint="",
request_params={},
response_check=resp_check,
timeout=5,
poke_interval=1,
)
assert task.execute(context={}) == "somedata"

@patch("airflow.providers.http.hooks.http.Session.send")
def test_poke_continues_for_http_500_with_extra_options_check_response_false(
self,
Expand Down