Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 47 additions & 4 deletions task-sdk/src/airflow/sdk/execution_time/comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
from functools import cached_property
from pathlib import Path
from socket import socket
from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Generic, Literal, TypeVar, Union
from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Generic, Literal, TypeVar, Union, overload
from uuid import UUID

import attrs
Expand Down Expand Up @@ -90,6 +90,12 @@
)
from airflow.sdk.exceptions import ErrorType

try:
    from socket import recv_fds
except ImportError:
    # socket.recv_fds is documented as available on Unix and Windows (so
    # effectively "everywhere"), but let's be safe and feature-detect: callers
    # check `recv_fds is None` before attempting to receive file descriptors.
    recv_fds = None  # type: ignore[assignment]

if TYPE_CHECKING:
from structlog.typing import FilteringBoundLogger as Logger

Expand Down Expand Up @@ -180,18 +186,41 @@ def send(self, msg: SendMsgType) -> ReceiveMsgType | None:
bytes = frame.as_bytes()

self.socket.sendall(bytes)
if isinstance(msg, ResendLoggingFD):
if recv_fds is None:
return None
# We need special handling here! The server can't send us the fd number, as the number on the
# supervisor will be different to in this process, so we have to mutate the message ourselves here.
frame, fds = self._read_frame(maxfds=1)
resp = self._from_frame(frame)
if TYPE_CHECKING:
assert isinstance(resp, SentFDs)
resp.fds = fds
# Since we know this is an expliclt SendFDs, and since this class is generic SendFDs might not
# always be in the return type union
return resp # type: ignore[return-value]

return self._get_response()

def _read_frame(self):
@overload
def _read_frame(self, maxfds: None = None) -> _ResponseFrame: ...

@overload
def _read_frame(self, maxfds: int) -> tuple[_ResponseFrame, list[int]]: ...

def _read_frame(self, maxfds: int | None = None) -> tuple[_ResponseFrame, list[int]] | _ResponseFrame:
"""
Get a message from the parent.

This will block until the message has been received.
"""
if self.socket:
self.socket.setblocking(True)
len_bytes = self.socket.recv(4)
fds = None
if maxfds:
len_bytes, fds, flag, address = recv_fds(self.socket, 4, maxfds)
else:
len_bytes = self.socket.recv(4)

if len_bytes == b"":
raise EOFError("Request socket closed before length")
Expand All @@ -207,7 +236,10 @@ def _read_frame(self):
if nread == 0:
raise EOFError(f"Request socket closed before response was complete ({self.id_counter=})")

return self.resp_decoder.decode(buffer)
resp = self.resp_decoder.decode(buffer)
if maxfds:
return resp, fds or []
return resp

def _from_frame(self, frame) -> ReceiveMsgType | None:
from airflow.sdk.exceptions import AirflowRuntimeError
Expand Down Expand Up @@ -520,6 +552,11 @@ class OKResponse(BaseModel):
type: Literal["OKResponse"] = "OKResponse"


class SentFDs(BaseModel):
    """Response carrying file descriptors passed over the comms socket.

    The descriptors themselves travel as ancillary socket data; the fd
    *numbers* in ``fds`` are only valid in the receiving process, which
    rewrites this field after receiving them via ``recv_fds``.
    """

    type: Literal["SentFDs"] = "SentFDs"
    # fd numbers as seen by the receiving process (populated client-side).
    fds: list[int]


ToTask = Annotated[
Union[
AssetResult,
Expand All @@ -529,6 +566,7 @@ class OKResponse(BaseModel):
DRCount,
ErrorResponse,
PrevSuccessfulDagRunResult,
SentFDs,
StartupDetails,
TaskRescheduleStartDate,
TICount,
Expand Down Expand Up @@ -710,6 +748,10 @@ class DeleteVariable(BaseModel):
type: Literal["DeleteVariable"] = "DeleteVariable"


class ResendLoggingFD(BaseModel):
    """Request asking the supervisor to send a fresh logging file descriptor.

    Answered with a ``SentFDs`` response whose descriptor is transmitted as
    ancillary socket data (see the supervisor's handling of this message).
    """

    type: Literal["ResendLoggingFD"] = "ResendLoggingFD"


class SetRenderedFields(BaseModel):
"""Payload for setting RTIF for a task instance."""

Expand Down Expand Up @@ -829,6 +871,7 @@ class GetDRCount(BaseModel):
TaskState,
TriggerDagRun,
DeleteVariable,
ResendLoggingFD,
],
Field(discriminator="type"),
]
38 changes: 38 additions & 0 deletions task-sdk/src/airflow/sdk/execution_time/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,9 @@
PrevSuccessfulDagRunResult,
PutVariable,
RescheduleTask,
ResendLoggingFD,
RetryTask,
SentFDs,
SetRenderedFields,
SetXCom,
SkipDownstreamTasks,
Expand All @@ -115,6 +117,11 @@
)
from airflow.sdk.execution_time.secrets_masker import mask_secret

try:
    from socket import send_fds
except ImportError:
    # socket.send_fds may be missing on some platforms; fall back to None so
    # callers can feature-detect before trying to pass file descriptors.
    send_fds = None  # type: ignore[assignment]

if TYPE_CHECKING:
from structlog.typing import FilteringBoundLogger, WrappedLogger

Expand Down Expand Up @@ -1218,6 +1225,12 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id:
inactive_assets_resp = self.client.task_instances.validate_inlets_and_outlets(msg.ti_id)
resp = InactiveAssetsResult.from_inactive_assets_response(inactive_assets_resp)
dump_opts = {"exclude_unset": True}
elif isinstance(msg, ResendLoggingFD):
# We need special handling here!
if send_fds is not None:
self._send_new_log_fd(req_id)
# Since we've sent the message, return. Nothing else in this ifelse/switch should return directly
return
else:
log.error("Unhandled request", msg=msg)
self.send_msg(
Expand All @@ -1232,6 +1245,31 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id:

self.send_msg(resp, request_id=req_id, error=None, **dump_opts)

def _send_new_log_fd(self, req_id: int) -> None:
    """Create a fresh log socketpair and send the child end to the subprocess.

    The read end stays registered in this supervisor's selector so log lines
    keep flowing to ``self.process_log`` (and optionally stdout); the write
    end is transmitted over ``self.stdin`` as ancillary data alongside a
    ``SentFDs`` response frame answering request ``req_id``.

    :param req_id: id of the request frame this response answers.
    :raises RuntimeError: if ``send_fds`` is unavailable on this platform.
    """
    if send_fds is None:
        raise RuntimeError("send_fds is not available on this platform")
    child_logs, read_logs = socketpair()

    # Forward subprocess log lines to the supervisor's process log, and
    # mirror them to stdout as well when configured to do so.
    target_loggers: tuple[FilteringBoundLogger, ...] = (self.process_log,)
    if self.subprocess_logs_to_stdout:
        target_loggers += (log,)

    self.selector.register(
        read_logs,
        selectors.EVENT_READ,
        make_buffered_socket_reader(
            process_log_messages_from_subprocess(target_loggers), on_close=self._on_socket_closed
        ),
    )
    # We don't explicitly close the old log socket, that will get handled for us if/when the other end is
    # closed (such as `sudo` would do for us automatically.) This also means that this feature _can_ be
    # used outside of an exec context if it is useful, as we can then have multiple things sending us logs,
    # such as the task process and its launched subprocess.

    # NOTE: the fd number in the frame body is the *supervisor's* number; the
    # receiving process gets its own renumbered descriptor from the ancillary
    # data and rewrites the field client-side after receipt.
    frame = _ResponseFrame(id=req_id, body=SentFDs(fds=[child_logs.fileno()]).model_dump())
    send_fds(self.stdin, [frame.as_bytes()], [child_logs.fileno()])
    child_logs.close()  # Close our copy now that the fd has been sent.


def in_process_api_server():
from airflow.api_fastapi.execution_api.app import InProcessExecutionAPI
Expand Down
14 changes: 14 additions & 0 deletions task-sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,9 @@
GetTICount,
InactiveAssetsResult,
RescheduleTask,
ResendLoggingFD,
RetryTask,
SentFDs,
SetRenderedFields,
SkipDownstreamTasks,
StartupDetails,
Expand Down Expand Up @@ -659,6 +661,18 @@ def startup() -> tuple[RuntimeTaskInstance, Context, Logger]:
if os.environ.get("_AIRFLOW__REEXECUTED_PROCESS") == "1" and os.environ.get("_AIRFLOW__STARTUP_MSG"):
# entrypoint of re-exec process
msg = TypeAdapter(StartupDetails).validate_json(os.environ["_AIRFLOW__STARTUP_MSG"])

logs = SUPERVISOR_COMMS.send(ResendLoggingFD())
if isinstance(logs, SentFDs):
from airflow.sdk.log import configure_logging

log_io = os.fdopen(logs.fds[0], "wb", buffering=0)
configure_logging(enable_pretty_log=False, output=log_io, sending_to_supervisor=True)
else:
print("Unable to re-configure logging after sudo, we didn't get an FD", file=sys.stderr)

# We delay this message until _after_ we've got the logging re-configured, otherwise it will show up
# on stdout
log.debug("Using serialized startup message from environment", msg=msg)
else:
# normal entry point
Expand Down
41 changes: 41 additions & 0 deletions task-sdk/tests/task_sdk/execution_time/test_supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from random import randint
from time import sleep
from typing import TYPE_CHECKING
from unittest import mock
from unittest.mock import MagicMock, patch

import httpx
Expand Down Expand Up @@ -85,7 +86,9 @@
PrevSuccessfulDagRunResult,
PutVariable,
RescheduleTask,
ResendLoggingFD,
RetryTask,
SentFDs,
SetRenderedFields,
SetXCom,
SucceedTask,
Expand Down Expand Up @@ -236,6 +239,44 @@ def subprocess_main():
]
)

def test_reopen_log_fd(self, captured_logs, client_with_ti_start):
    """Logs written to the fd obtained via ResendLoggingFD reach the supervisor.

    The subprocess asks for a new logging fd and writes one structured log line
    on each channel (the pre-existing logging config and the new fd); both must
    appear in the supervisor's captured logs.
    """

    def subprocess_main():
        # This is run in the subprocess!

        # Ensure we follow the "protocol" and get the startup message before we do anything else
        comms = CommsDecoder()
        comms._get_response()

        # Request a fresh logging fd, then emit one line via stdlib logging
        # ("old socket") and one raw JSON line on the new fd ("new socket").
        logs = comms.send(ResendLoggingFD())
        assert isinstance(logs, SentFDs)
        fd = os.fdopen(logs.fds[0], "w")
        logging.root.info("Log on old socket")
        json.dump({"level": "info", "event": "Log on new socket"}, fp=fd)

    proc = ActivitySubprocess.start(
        dag_rel_path=os.devnull,
        bundle_info=FAKE_BUNDLE,
        what=TaskInstance(
            id="4d828a62-a417-4936-a7a6-2b3fabacecab",
            task_id="b",
            dag_id="c",
            run_id="d",
            try_number=1,
        ),
        client=client_with_ti_start,
        target=subprocess_main,
    )

    rc = proc.wait()

    assert rc == 0
    # Order is not deterministic: the two lines arrive on different sockets.
    assert captured_logs == unordered(
        [
            {"event": "Log on new socket", "level": "info", "logger": "task", "timestamp": mock.ANY},
            {"event": "Log on old socket", "level": "info", "logger": "root", "timestamp": mock.ANY},
        ]
    )

def test_subprocess_sigkilled(self, client_with_ti_start):
main_pid = os.getpid()

Expand Down