75 changes: 46 additions & 29 deletions task-sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -33,7 +33,6 @@
 from collections.abc import Callable, Generator
 from contextlib import contextmanager, suppress
 from datetime import datetime, timezone
-from functools import lru_cache
 from http import HTTPStatus
 from socket import socket, socketpair
 from typing import (
@@ -827,8 +826,10 @@ def _check_subprocess_exit(
         return self._exit_code


-@lru_cache
-def _get_remote_logging_conn(conn_id: str, client: Client) -> Connection | None:
+_REMOTE_LOGGING_CONN_CACHE: dict[str, Connection | None] = {}
+
+
+def _fetch_remote_logging_conn(conn_id: str, client: Client) -> Connection | None:
     """
     Fetch and cache connection for remote logging.

@@ -837,18 +838,22 @@ def _get_remote_logging_conn(conn_id: str, client: Client) -> Connection | None:
         client: API client for making requests

     Returns:
-        Connection object or None if not found
+        Connection object or None if not found.
     """
     # Since we need to use the API Client directly, we can't use Connection.get as that would try to use
     # SUPERVISOR_COMMS

     # TODO: Store in the SecretsCache if its enabled - see #48858

+    if conn_id in _REMOTE_LOGGING_CONN_CACHE:
+        return _REMOTE_LOGGING_CONN_CACHE[conn_id]
+
     backends = ensure_secrets_backend_loaded()
     for secrets_backend in backends:
         try:
             conn = secrets_backend.get_connection(conn_id=conn_id)
             if conn:
+                _REMOTE_LOGGING_CONN_CACHE[conn_id] = conn
                 return conn
         except Exception:
             log.exception(
@@ -862,8 +867,12 @@ def _get_remote_logging_conn(conn_id: str, client: Client) -> Connection | None:
         conn_result = ConnectionResult.from_conn_response(conn)
         from airflow.sdk.definitions.connection import Connection

-        return Connection(**conn_result.model_dump(exclude={"type"}, by_alias=True))
-    return None
+        result: Connection | None = Connection(**conn_result.model_dump(exclude={"type"}, by_alias=True))
+    else:
+        result = None
+
+    _REMOTE_LOGGING_CONN_CACHE[conn_id] = result
+    return result


 @contextlib.contextmanager
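Why the switch away from `@lru_cache` matters: `lru_cache` keys its entries on every argument, so caching `_get_remote_logging_conn(conn_id, client)` kept a strong reference to each per-task `Client` for the lifetime of the process. The module-level dict keyed on `conn_id` alone caches only the resolved connection. A minimal standalone sketch of the difference (illustrative names, not the supervisor code):

```python
from functools import lru_cache


@lru_cache
def cached_fetch(conn_id: str, client: object) -> str:
    # The cache entry is keyed on (conn_id, client), so every distinct
    # client object stays strongly referenced until the cache is cleared.
    return f"conn-for-{conn_id}"


_CONN_CACHE: dict[str, str] = {}


def dict_fetch(conn_id: str, client: object) -> str:
    # Keyed on conn_id only: the client is used transiently (if at all)
    # and can be garbage collected as soon as the caller drops it.
    if conn_id not in _CONN_CACHE:
        _CONN_CACHE[conn_id] = f"conn-for-{conn_id}"
    return _CONN_CACHE[conn_id]
```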
@@ -878,7 +887,8 @@ def _remote_logging_conn(client: Client):
     This is needed as the BaseHook.get_connection looks for SUPERVISOR_COMMS, but we are still in the
     supervisor process when this is needed, so that doesn't exist yet.

-    This function uses @lru_cache for connection caching to avoid repeated API calls.
+    The connection details are fetched eagerly on every invocation to avoid retaining
+    per-task API client instances in global caches.
     """
     from airflow.sdk.log import load_remote_conn_id, load_remote_log_handler

@@ -887,8 +897,8 @@ def _remote_logging_conn(client: Client):
         yield
         return

-    # Use cached connection fetcher
-    conn = _get_remote_logging_conn(conn_id, client)
+    # Fetch connection details on-demand without caching the entire API client instance
+    conn = _fetch_remote_logging_conn(conn_id, client)

     if conn:
         key = f"AIRFLOW_CONN_{conn_id.upper()}"
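For context, `_remote_logging_conn` hands the resolved connection to downstream code through an `AIRFLOW_CONN_<ID>` environment variable, since `BaseHook.get_connection` cannot rely on SUPERVISOR_COMMS inside the supervisor process. A hedged sketch of that handoff pattern; the save-and-restore logic here is illustrative, not the exact supervisor implementation:

```python
import os
from contextlib import contextmanager


@contextmanager
def expose_conn_as_env(conn_id: str, conn_uri: str):
    key = f"AIRFLOW_CONN_{conn_id.upper()}"
    old = os.environ.get(key)
    os.environ[key] = conn_uri  # hooks resolve the connection from this variable
    try:
        yield
    finally:
        # Restore the previous state so the variable does not leak past the block
        if old is None:
            os.environ.pop(key, None)
        else:
            os.environ[key] = old
```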
@@ -1899,9 +1909,11 @@ def supervise(
     if not dag_rel_path:
         raise ValueError("dag_path is required")

+    close_client = False
     if not client:
         limits = httpx.Limits(max_keepalive_connections=1, max_connections=10)
         client = Client(base_url=server or "", limits=limits, dry_run=dry_run, token=token)
+        close_client = True

     start = time.monotonic()

@@ -1920,24 +1932,29 @@

     reset_secrets_masker()

-    process = ActivitySubprocess.start(
-        dag_rel_path=dag_rel_path,
-        what=ti,
-        client=client,
-        logger=logger,
-        bundle_info=bundle_info,
-        subprocess_logs_to_stdout=subprocess_logs_to_stdout,
-    )
+    try:
+        process = ActivitySubprocess.start(
+            dag_rel_path=dag_rel_path,
+            what=ti,
+            client=client,
+            logger=logger,
+            bundle_info=bundle_info,
+            subprocess_logs_to_stdout=subprocess_logs_to_stdout,
+        )

-    exit_code = process.wait()
-    end = time.monotonic()
-    log.info(
-        "Task finished",
-        task_instance_id=str(ti.id),
-        exit_code=exit_code,
-        duration=end - start,
-        final_state=process.final_state,
-    )
-    if log_path and log_file_descriptor:
-        log_file_descriptor.close()
-    return exit_code
+        exit_code = process.wait()
+        end = time.monotonic()
+        log.info(
+            "Task finished",
+            task_instance_id=str(ti.id),
+            exit_code=exit_code,
+            duration=end - start,
+            final_state=process.final_state,
+        )
+        return exit_code
+    finally:
+        if log_path and log_file_descriptor:
+            log_file_descriptor.close()
+        if close_client and client:
+            with suppress(Exception):
+                client.close()
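The `supervise()` change follows the usual resource-ownership rule: close the client only when this function created it, and do all cleanup in `finally` so the log file descriptor and the client are released even if `ActivitySubprocess.start` or `wait` raises. A minimal sketch of the pattern, assuming a plain `httpx.Client` in place of the SDK client:

```python
import httpx


def run_with_client(client: httpx.Client | None = None) -> None:
    close_client = False
    if client is None:
        client = httpx.Client()
        close_client = True  # this function owns the client and must close it
    try:
        # Placeholder for real work; build_request performs no network I/O.
        client.build_request("GET", "https://example.invalid/")
    finally:
        if close_client:
            client.close()  # callers who passed a client keep control of its lifetime
```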
41 changes: 41 additions & 0 deletions task-sdk/tests/task_sdk/execution_time/test_supervisor.py
@@ -2432,6 +2432,47 @@ def mock_upload_to_remote(process_log, ti):
     assert connection_available["conn_uri"] is not None, "Connection URI was None during upload"


+def test_remote_logging_conn_caches_connection_not_client(monkeypatch):
+    """Test that connection caching doesn't retain API client references."""
+    import gc
+    import weakref
+
+    from airflow.sdk import log as sdk_log
+    from airflow.sdk.execution_time import supervisor
+
+    class ExampleBackend:
+        def __init__(self):
+            self.calls = 0
+
+        def get_connection(self, conn_id: str):
+            self.calls += 1
+            from airflow.sdk.definitions.connection import Connection
+
+            return Connection(conn_id=conn_id, conn_type="example")
+
+    backend = ExampleBackend()
+    monkeypatch.setattr(supervisor, "ensure_secrets_backend_loaded", lambda: [backend])
+    monkeypatch.setattr(sdk_log, "load_remote_log_handler", lambda: object())
+    monkeypatch.setattr(sdk_log, "load_remote_conn_id", lambda: "test_conn")
+    monkeypatch.delenv("AIRFLOW_CONN_TEST_CONN", raising=False)
+
+    def noop_request(request: httpx.Request) -> httpx.Response:
+        return httpx.Response(200)
+
+    clients = []
+    for _ in range(3):
+        client = make_client(transport=httpx.MockTransport(noop_request))
+        clients.append(weakref.ref(client))
+        with _remote_logging_conn(client):
+            pass
+        client.close()
+        del client
+
+    gc.collect()
+    assert backend.calls == 1, "Connection should be cached, not fetched multiple times"
+    assert all(ref() is None for ref in clients), "Client instances should be garbage collected"
+
+
 def test_process_log_messages_from_subprocess(monkeypatch, caplog):
     from airflow.sdk._shared.logging.structlog import PER_LOGGER_LEVELS

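The weakref-plus-gc pattern in this test is a general way to assert that no global cache still pins an object. A self-contained sketch of the check, with assumed names:

```python
import gc
import weakref


class Resource:
    """Stand-in for the per-task API client."""


def assert_collected() -> None:
    obj = Resource()
    ref = weakref.ref(obj)  # weak reference does not keep obj alive
    del obj                 # drop the only strong reference
    gc.collect()            # break any reference cycles before checking
    assert ref() is None, "a lingering reference would keep the object alive"
```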