Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 26 additions & 4 deletions providers/http/src/airflow/providers/http/operators/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,14 @@
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, Callable

from aiohttp import BasicAuth
from requests import Response

from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook
from airflow.models import BaseOperator
from airflow.providers.http.triggers.http import HttpTrigger
from airflow.providers.http.triggers.http import HttpTrigger, serialize_auth_type
from airflow.utils.helpers import merge_dicts

if TYPE_CHECKING:
Expand Down Expand Up @@ -122,7 +123,7 @@ def __init__(
request_kwargs: dict[str, Any] | None = None,
http_conn_id: str = "http_default",
log_response: bool = False,
auth_type: type[AuthBase] | None = None,
auth_type: type[AuthBase] | type[BasicAuth] | None = None,
tcp_keep_alive: bool = True,
tcp_keep_alive_idle: int = 120,
tcp_keep_alive_count: int = 20,
Expand Down Expand Up @@ -221,7 +222,7 @@ def execute_async(self, context: Context) -> None:
self.defer(
trigger=HttpTrigger(
http_conn_id=self.http_conn_id,
auth_type=self.auth_type,
auth_type=serialize_auth_type(self._resolve_auth_type()),
method=self.method,
endpoint=self.endpoint,
headers=self.headers,
Expand All @@ -231,6 +232,27 @@ def execute_async(self, context: Context) -> None:
method_name="execute_complete",
)

def _resolve_auth_type(self) -> type[AuthBase] | type[BasicAuth] | None:
"""
Resolve the authentication type for the HTTP request.

If auth_type is not explicitly set, attempt to infer it from the connection configuration.
For connections with login/password, default to BasicAuth.

:return: The resolved authentication type class, or None if no auth is needed.
"""
if self.auth_type is not None:
return self.auth_type

try:
conn = BaseHook.get_connection(self.http_conn_id)
if conn.login or conn.password:
return BasicAuth
except Exception as e:
self.log.warning("Failed to resolve auth type from connection: %s", e)

return None

def process_response(self, context: Context, response: Response | list[Response]) -> Any:
"""Process the response."""
from airflow.utils.operator_helpers import determine_kwargs
Expand Down Expand Up @@ -291,7 +313,7 @@ def paginate_async(
self.defer(
trigger=HttpTrigger(
http_conn_id=self.http_conn_id,
auth_type=self.auth_type,
auth_type=serialize_auth_type(self._resolve_auth_type()),
method=self.method,
**self._merge_next_page_parameters(next_page_params),
),
Expand Down
22 changes: 19 additions & 3 deletions providers/http/src/airflow/providers/http/triggers/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import base64
import pickle
from collections.abc import AsyncIterator
from importlib import import_module
from typing import TYPE_CHECKING, Any

import aiohttp
Expand All @@ -35,6 +36,21 @@
from aiohttp.client_reqrep import ClientResponse


def serialize_auth_type(auth: str | type | None) -> str | None:
if auth is None:
return None
if isinstance(auth, str):
return auth
return f"{auth.__module__}.{auth.__qualname__}"


def deserialize_auth_type(path: str | None) -> type | None:
if path is None:
return None
module_path, cls_name = path.rsplit(".", 1)
return getattr(import_module(module_path), cls_name)


class HttpTrigger(BaseTrigger):
"""
HttpTrigger run on the trigger worker.
Expand All @@ -56,7 +72,7 @@ class HttpTrigger(BaseTrigger):
def __init__(
self,
http_conn_id: str = "http_default",
auth_type: Any = None,
auth_type: str | None = None,
method: str = "POST",
endpoint: str | None = None,
headers: dict[str, str] | None = None,
Expand All @@ -66,7 +82,7 @@ def __init__(
super().__init__()
self.http_conn_id = http_conn_id
self.method = method
self.auth_type = auth_type
self.auth_type = deserialize_auth_type(auth_type)
self.endpoint = endpoint
self.headers = headers
self.data = data
Expand All @@ -79,7 +95,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
{
"http_conn_id": self.http_conn_id,
"method": self.method,
"auth_type": self.auth_type,
"auth_type": serialize_auth_type(self.auth_type),
"endpoint": self.endpoint,
"headers": self.headers,
"data": self.data,
Expand Down
57 changes: 56 additions & 1 deletion providers/http/tests/unit/http/operators/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,21 @@
import contextlib
import json
import pickle
from types import SimpleNamespace
from unittest import mock
from unittest.mock import call, patch

import pytest
import tenacity
from aiohttp import BasicAuth
from requests import Response
from requests.models import RequestEncodingMixin

from airflow.exceptions import AirflowException, TaskDeferred
from airflow.hooks import base
from airflow.providers.http.hooks.http import HttpHook
from airflow.providers.http.operators.http import HttpOperator
from airflow.providers.http.triggers.http import HttpTrigger
from airflow.providers.http.triggers.http import HttpTrigger, serialize_auth_type


@mock.patch.dict("os.environ", AIRFLOW_CONN_HTTP_EXAMPLE="http://www.example.com")
Expand Down Expand Up @@ -92,6 +95,7 @@ def test_filters_response(self, requests_mock):
result = operator.execute({})
assert result == {"value": 5}

@pytest.mark.db_test
def test_async_defer_successfully(self, requests_mock):
operator = HttpOperator(
task_id="test_HTTP_op",
Expand Down Expand Up @@ -186,6 +190,7 @@ def pagination_function(response: Response) -> dict | None:

assert result == [5, 10]

@pytest.mark.db_test
def test_async_pagination(self, requests_mock):
"""
Test that the HttpOperator calls asynchronously and repetitively
Expand Down Expand Up @@ -300,3 +305,53 @@ def pagination_function(response: Response) -> dict | None:
)

assert mock_run_with_advanced_retry.call_count == 2

def _capture_defer(self, monkeypatch):
captured = {}

def _fake_defer(self, *, trigger, method_name, **kwargs):
captured["trigger"] = trigger
captured["kwargs"] = kwargs

monkeypatch.setattr(HttpOperator, "defer", _fake_defer)
return captured

@pytest.mark.parametrize(
"login, password, auth_type, expect_cls",
[
("user", "password", None, BasicAuth),
(None, None, None, type(None)),
("user", "password", BasicAuth, BasicAuth),
],
)
def test_auth_type_is_serialised_as_string(self, monkeypatch, login, password, auth_type, expect_cls):
monkeypatch.setattr(
base.BaseHook, "get_connection", lambda _cid: SimpleNamespace(login=login, password=password)
)
captured = self._capture_defer(monkeypatch)

HttpOperator(task_id="test_HTTP_op", deferrable=True, auth_type=auth_type).execute(context={})

trigger = captured["trigger"]
kwargs = captured["trigger"].serialize()[1]

expected_str = serialize_auth_type(expect_cls) if expect_cls is not type(None) else None
assert kwargs["auth_type"] == expected_str

assert trigger.auth_type == expect_cls or (trigger.auth_type is None and expect_cls is type(None))

def test_resolve_auth_type_variants(self, monkeypatch):
monkeypatch.setattr(
base.BaseHook, "get_connection", lambda _cid: SimpleNamespace(login="user", password="password")
)
assert HttpOperator(task_id="test_HTTP_op_1")._resolve_auth_type() is BasicAuth

class DummyAuth:
def __init__(self, *_, **__): ...

assert HttpOperator(task_id="test_HTTP_op_2", auth_type=DummyAuth)._resolve_auth_type() is DummyAuth

monkeypatch.setattr(
base.BaseHook, "get_connection", lambda _cid: SimpleNamespace(login=None, password=None)
)
assert HttpOperator(task_id="test_HTTP_op_3")._resolve_auth_type() is None