9 changes: 6 additions & 3 deletions airflow-core/src/airflow/dag_processing/manager.py
@@ -77,6 +77,8 @@
from airflow.utils.sqlalchemy import prohibit_commit, with_row_locks

if TYPE_CHECKING:
from socket import socket

from sqlalchemy.orm import Session

from airflow.callbacks.callback_requests import CallbackRequest
@@ -388,7 +390,7 @@ def _service_processor_sockets(self, timeout: float | None = 1.0):
"""
events = self.selector.select(timeout=timeout)
for key, _ in events:
socket_handler = key.data
socket_handler, on_close = key.data

# BrokenPipeError should be caught and treated as if the handler returned false, similar
# to EOF case
@@ -397,8 +399,9 @@
except (BrokenPipeError, ConnectionResetError):
need_more = False
if not need_more:
self.selector.unregister(key.fileobj)
key.fileobj.close() # type: ignore[union-attr]
sock: socket = key.fileobj # type: ignore[assignment]
on_close(sock)
sock.close()

def _queue_requested_files_for_parsing(self) -> None:
"""Queue any files requested for parsing as requested by users via UI/API."""
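Note on the hunk above: the `key.data` registered with the selector is now a `(handler, on_close)` pair instead of a bare handler, so per-socket cleanup runs before the socket is closed. A minimal sketch of the pattern, with hypothetical `on_read`/`on_close` callables standing in for Airflow's buffered socket readers, and assuming the `unregister` call now lives inside the `on_close` hook as the hunk suggests:

```python
import selectors
import socket

sel = selectors.DefaultSelector()

def on_read(sock: socket.socket) -> bool:
    # Hypothetical reader callback: False means "no more data expected" (EOF).
    return bool(sock.recv(4096))

def on_close(sock: socket.socket) -> None:
    # Per-socket cleanup hook: drop selector bookkeeping before close.
    sel.unregister(sock)

def register(sock: socket.socket) -> None:
    # data carries the (handler, on_close) pair unpacked in the service loop.
    sel.register(sock, selectors.EVENT_READ, data=(on_read, on_close))

def serve_once(timeout: float | None = 1.0) -> None:
    for key, _ in sel.select(timeout=timeout):
        handler, close_cb = key.data
        try:
            need_more = handler(key.fileobj)
        except (BrokenPipeError, ConnectionResetError):
            # Treated the same as EOF, per the comment in the hunk.
            need_more = False
        if not need_more:
            close_cb(key.fileobj)
            key.fileobj.close()
```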
33 changes: 17 additions & 16 deletions airflow-core/src/airflow/dag_processing/processor.py
@@ -68,7 +68,6 @@ class DagFileParseRequest(BaseModel):
bundle_path: Path
"""Passing bundle path around lets us figure out relative file path."""

requests_fd: int
callback_requests: list[CallbackRequest] = Field(default_factory=list)
type: Literal["DagFileParseRequest"] = "DagFileParseRequest"

@@ -102,18 +101,16 @@ class DagFileParsingResult(BaseModel):
def _parse_file_entrypoint():
import structlog

from airflow.sdk.execution_time import task_runner
from airflow.sdk.execution_time import comms, task_runner

# Parse DAG file, send JSON back up!
comms_decoder = task_runner.CommsDecoder[ToDagProcessor, ToManager](
input=sys.stdin,
decoder=TypeAdapter[ToDagProcessor](ToDagProcessor),
comms_decoder = comms.CommsDecoder[ToDagProcessor, ToManager](
body_decoder=TypeAdapter[ToDagProcessor](ToDagProcessor),
)

msg = comms_decoder.get_message()
msg = comms_decoder._get_response()
if not isinstance(msg, DagFileParseRequest):
raise RuntimeError(f"Required first message to be a DagFileParseRequest, it was {msg}")
comms_decoder.request_socket = os.fdopen(msg.requests_fd, "wb", buffering=0)

task_runner.SUPERVISOR_COMMS = comms_decoder
log = structlog.get_logger(logger_name="task")
@@ -125,7 +122,7 @@ def _parse_file_entrypoint():

result = _parse_file(msg, log)
if result is not None:
comms_decoder.send_request(log, result)
comms_decoder.send(result)


def _parse_file(msg: DagFileParseRequest, log: FilteringBoundLogger) -> DagFileParsingResult | None:
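Context for the `requests_fd` removal above: the child no longer gets a separate write-back file descriptor; a single duplex socket installed as the child's stdin now carries both directions. A self-contained sketch of that pattern (toy child process, not the real `WatchedSubprocess` wiring):

```python
import socket
import subprocess
import sys

# One socketpair; the child end becomes the child's stdin (fd 0).
parent_end, child_end = socket.socketpair()
child = subprocess.Popen(
    [sys.executable, "-c",
     "import socket; s = socket.socket(fileno=0); s.sendall(s.recv(5))"],
    stdin=child_end,
)
child_end.close()

parent_end.sendall(b"hello")   # request goes down the child's stdin...
print(parent_end.recv(5))      # ...and the echo comes back up the same fd
child.wait()
```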
@@ -266,20 +263,18 @@ def _on_child_started(
msg = DagFileParseRequest(
file=os.fspath(path),
bundle_path=bundle_path,
requests_fd=self._requests_fd,
callback_requests=callbacks,
)
self.send_msg(msg)
self.send_msg(msg, request_id=0)

def _handle_request(self, msg: ToManager, log: FilteringBoundLogger) -> None: # type: ignore[override]
def _handle_request(self, msg: ToManager, log: FilteringBoundLogger, req_id: int) -> None: # type: ignore[override]
from airflow.sdk.api.datamodels._generated import ConnectionResponse, VariableResponse

resp: BaseModel | None = None
dump_opts = {}
if isinstance(msg, DagFileParsingResult):
self.parsing_result = msg
return
if isinstance(msg, GetConnection):
elif isinstance(msg, GetConnection):
conn = self.client.connections.get(msg.conn_id)
if isinstance(conn, ConnectionResponse):
conn_result = ConnectionResult.from_conn_response(conn)
@@ -301,18 +296,24 @@ def _handle_request(self, msg: ToManager, log: FilteringBoundLogger) -> None: #
resp = self.client.variables.delete(msg.key)
else:
log.error("Unhandled request", msg=msg)
self.send_msg(
None,
request_id=req_id,
error=ErrorResponse(
detail={"status_code": 400, "message": "Unhandled request"},
),
)
return

if resp:
self.send_msg(resp, **dump_opts)
self.send_msg(resp, request_id=req_id, error=None, **dump_opts)

@property
def is_ready(self) -> bool:
if self._check_subprocess_exit() is None:
# Process still alive, def can't be finished yet
return False

return self._num_open_sockets == 0
return not self._open_sockets

def wait(self) -> int:
raise NotImplementedError(f"Don't call wait on {type(self).__name__} objects")
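The reworked `_handle_request` threads `req_id` through every reply, and an unhandled message now gets an explicit error reply instead of silence. A toy dispatcher (all names hypothetical) showing the correlation pattern:

```python
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any

@dataclass
class Reply:
    # Mirrors the send_msg(..., request_id=..., error=...) shape above.
    request_id: int
    body: Any = None
    error: dict | None = None

def dispatch(msg: Any, req_id: int, handlers: dict[type, Callable[[Any], Any]]) -> Reply:
    handler = handlers.get(type(msg))
    if handler is None:
        # Unhandled requests produce an error reply that still echoes the id.
        return Reply(request_id=req_id, error={"status_code": 400, "message": "Unhandled request"})
    return Reply(request_id=req_id, body=handler(msg))
```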
126 changes: 74 additions & 52 deletions airflow-core/src/airflow/jobs/triggerer_job_runner.py
@@ -28,6 +28,7 @@
from collections.abc import Generator, Iterable
from contextlib import suppress
from datetime import datetime
from socket import socket
from traceback import format_exception
from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Literal, TypedDict, Union

@@ -43,6 +44,7 @@
from airflow.jobs.job import perform_heartbeat
from airflow.models.trigger import Trigger
from airflow.sdk.execution_time.comms import (
CommsDecoder,
ConnectionResult,
DagRunStateResult,
DRCount,
@@ -58,6 +60,7 @@
TICount,
VariableResult,
XComResult,
_RequestFrame,
)
from airflow.sdk.execution_time.supervisor import WatchedSubprocess, make_buffered_socket_reader
from airflow.stats import Stats
@@ -70,8 +73,6 @@
from airflow.utils.session import provide_session

if TYPE_CHECKING:
from socket import socket

from sqlalchemy.orm import Session
from structlog.typing import FilteringBoundLogger, WrappedLogger

@@ -181,7 +182,6 @@ class messages:
class StartTriggerer(BaseModel):
"""Tell the async trigger runner process to start, and where to send status update messages."""

requests_fd: int
type: Literal["StartTriggerer"] = "StartTriggerer"

class TriggerStateChanges(BaseModel):
@@ -295,7 +295,7 @@ class TriggerRunnerSupervisor(WatchedSubprocess):
"""
TriggerRunnerSupervisor is responsible for monitoring the subprocess and marshalling DB access.

This class (which runs in the main process) is responsible for querying the DB, sending RunTrigger
This class (which runs in the main/sync process) is responsible for querying the DB, sending RunTrigger
workload messages to the subprocess, and collecting results and updating them in the DB.
"""

@@ -342,8 +342,8 @@ def start( # type: ignore[override]
):
proc = super().start(id=job.id, job=job, target=cls.run_in_process, logger=logger, **kwargs)

msg = messages.StartTriggerer(requests_fd=proc._requests_fd)
proc.send_msg(msg)
msg = messages.StartTriggerer()
proc.send_msg(msg, request_id=0)
return proc

@functools.cached_property
@@ -355,7 +355,7 @@ def client(self) -> Client:
client.base_url = "http://in-process.invalid./" # type: ignore[assignment]
return client

def _handle_request(self, msg: ToTriggerSupervisor, log: FilteringBoundLogger) -> None: # type: ignore[override]
def _handle_request(self, msg: ToTriggerSupervisor, log: FilteringBoundLogger, req_id: int) -> None: # type: ignore[override]
from airflow.sdk.api.datamodels._generated import (
ConnectionResponse,
TaskStatesResponse,
@@ -454,8 +454,7 @@ def _handle_request(self, msg: ToTriggerSupervisor, log: FilteringBoundLogger) -
else:
raise ValueError(f"Unknown message type {type(msg)}")

if resp:
self.send_msg(resp, **dump_opts)
self.send_msg(resp, request_id=req_id, error=None, **dump_opts)

def run(self) -> None:
"""Run synchronously and handle all database reads/writes."""
@@ -628,7 +627,7 @@ def _register_pipe_readers(self, stdout: socket, stderr: socket, requests: socke
),
)

def _process_log_messages_from_subprocess(self) -> Generator[None, bytes, None]:
def _process_log_messages_from_subprocess(self) -> Generator[None, bytes | bytearray, None]:
import msgspec
from structlog.stdlib import NAME_TO_LEVEL
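The generator's send type widens to `bytes | bytearray` because the supervisor may push either into these readers. A toy line-buffering generator in the same push style (hypothetical helper, not the real `make_buffered_socket_reader`):

```python
from collections.abc import Callable, Generator

def line_buffer(on_line: Callable[[bytes], None]) -> Generator[None, bytes | bytearray, None]:
    buffer = bytearray()
    while True:
        chunk = yield          # the caller .send()s raw chunks of either type
        buffer += chunk
        while (idx := buffer.find(b"\n")) != -1:
            on_line(bytes(buffer[: idx + 1]))
            del buffer[: idx + 1]

gen = line_buffer(print)
next(gen)                              # prime the generator to its first yield
gen.send(b"partial")
gen.send(bytearray(b" line\nrest"))    # prints b'partial line\n'
```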

@@ -691,14 +690,60 @@ class TriggerDetails(TypedDict):
events: int


@attrs.define(kw_only=True)
class TriggerCommsDecoder(CommsDecoder[ToTriggerRunner, ToTriggerSupervisor]):
_async_writer: asyncio.StreamWriter = attrs.field(alias="async_writer")
_async_reader: asyncio.StreamReader = attrs.field(alias="async_reader")

body_decoder: TypeAdapter[ToTriggerRunner] = attrs.field(
factory=lambda: TypeAdapter(ToTriggerRunner), repr=False
)

_lock: asyncio.Lock = attrs.field(factory=asyncio.Lock, repr=False)

def _read_frame(self):
from asgiref.sync import async_to_sync

return async_to_sync(self._aread_frame)()

def send(self, msg: ToTriggerSupervisor) -> ToTriggerRunner | None:
from asgiref.sync import async_to_sync

return async_to_sync(self.asend)(msg)

async def _aread_frame(self):
len_bytes = await self._async_reader.readexactly(4)
len = int.from_bytes(len_bytes, byteorder="big")
if len >= 2**32:
raise OverflowError(f"Refusing to receive messages larger than 4GiB {len=}")

buffer = await self._async_reader.readexactly(len)
return self.resp_decoder.decode(buffer)

async def _aget_response(self, expect_id: int) -> ToTriggerRunner | None:
frame = await self._aread_frame()
if frame.id != expect_id:
# Given the lock we take out in `asend`, this _shouldn't_ be possible, but I'd rather fail with
# this explicit error than return the wrong type of message back to a Trigger
raise RuntimeError(f"Response read out of order! Got {frame.id=}, {expect_id=}")
return self._from_frame(frame)

async def asend(self, msg: ToTriggerSupervisor) -> ToTriggerRunner | None:
frame = _RequestFrame(id=next(self.id_counter), body=msg.model_dump())
bytes = frame.as_bytes()

async with self._lock:
self._async_writer.write(bytes)

return await self._aget_response(frame.id)
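`TriggerCommsDecoder` speaks a length-prefixed frame protocol: a 4-byte big-endian length followed by the encoded frame. A condensed sketch, under the assumption that `_RequestFrame` serializes to a JSON object with `id` and `body` fields:

```python
import asyncio
import json

def encode_frame(frame_id: int, body: dict) -> bytes:
    # 4-byte big-endian length prefix + JSON payload (assumed layout).
    payload = json.dumps({"id": frame_id, "body": body}).encode()
    return len(payload).to_bytes(4, byteorder="big") + payload

async def decode_frame(reader: asyncio.StreamReader) -> dict:
    length = int.from_bytes(await reader.readexactly(4), byteorder="big")
    return json.loads(await reader.readexactly(length))
```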


class TriggerRunner:
"""
Runtime environment for all triggers.

Mainly runs inside its own thread, where it hands control off to an asyncio
event loop, but is also sometimes interacted with from the main thread
(where all the DB queries are done). All communication between threads is
done via Deques.
Mainly runs inside its own process, where it hands control off to an asyncio
event loop. All communication between this and its (sync) supervisor is done via sockets.
"""

# Maps trigger IDs to their running tasks and other info
@@ -726,10 +771,7 @@ class TriggerRunner:
# TODO: connect this to the parent process
log: FilteringBoundLogger = structlog.get_logger()

requests_sock: asyncio.StreamWriter
response_sock: asyncio.StreamReader

decoder: TypeAdapter[ToTriggerRunner]
comms_decoder: TriggerCommsDecoder

def __init__(self):
super().__init__()
@@ -740,7 +782,6 @@ def __init__(self):
self.events = deque()
self.failed_triggers = deque()
self.job_id = None
self.decoder = TypeAdapter(ToTriggerRunner)

def run(self):
"""Sync entrypoint - just run a run in an async loop."""
@@ -796,36 +837,21 @@ async def init_comms(self):
"""
from airflow.sdk.execution_time import task_runner

loop = asyncio.get_event_loop()
# Yes, we read and write to stdin! It's a socket, not a normal stdin.
reader, writer = await asyncio.open_connection(sock=socket(fileno=0))

comms_decoder = task_runner.CommsDecoder[ToTriggerRunner, ToTriggerSupervisor](
input=sys.stdin,
decoder=self.decoder,
self.comms_decoder = TriggerCommsDecoder(
async_writer=writer,
async_reader=reader,
)

task_runner.SUPERVISOR_COMMS = comms_decoder

async def connect_stdin() -> asyncio.StreamReader:
reader = asyncio.StreamReader()
protocol = asyncio.StreamReaderProtocol(reader)
await loop.connect_read_pipe(lambda: protocol, sys.stdin)
return reader

self.response_sock = await connect_stdin()
task_runner.SUPERVISOR_COMMS = self.comms_decoder

line = await self.response_sock.readline()
msg = await self.comms_decoder._aget_response(expect_id=0)

msg = self.decoder.validate_json(line)
if not isinstance(msg, messages.StartTriggerer):
raise RuntimeError(f"Required first message to be a messages.StartTriggerer, it was {msg}")

comms_decoder.request_socket = os.fdopen(msg.requests_fd, "wb", buffering=0)
writer_transport, writer_protocol = await loop.connect_write_pipe(
lambda: asyncio.streams.FlowControlMixin(loop=loop),
comms_decoder.request_socket,
)
self.requests_sock = asyncio.streams.StreamWriter(writer_transport, writer_protocol, None, loop)

async def create_triggers(self):
"""Drain the to_create queue and create all new triggers that have been requested in the DB."""
while self.to_create:
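A self-contained demo (not Airflow code) of what `init_comms` above relies on: one duplex socket can back both an `asyncio.StreamReader` and a `StreamWriter`, which is why the runner can read and write fd 0:

```python
import asyncio
import socket

async def main() -> None:
    a, b = socket.socketpair()           # stands in for the inherited stdin socket
    reader, writer = await asyncio.open_connection(sock=a)
    b.sendall(b"ping")
    print(await reader.readexactly(4))   # b'ping'
    writer.write(b"pong")
    await writer.drain()
    print(b.recv(4))                     # b'pong'

asyncio.run(main())
```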
@@ -934,8 +960,6 @@ async def cleanup_finished_triggers(self) -> list[int]:
return finished_ids

async def sync_state_to_supervisor(self, finished_ids: list[int]):
from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

# Copy out of our deques in threadsafe manner to sync state with parent
events_to_send = []
while self.events:
@@ -961,19 +985,17 @@ async def sync_state_to_supervisor(self, finished_ids: list[int]):
if not finished_ids:
msg.finished = None

# Block triggers from making any requests for the duration of this
async with SUPERVISOR_COMMS.lock:
# Tell the monitor that we've finished triggers so it can update things
self.requests_sock.write(msg.model_dump_json(exclude_none=True).encode() + b"\n")
line = await self.response_sock.readline()

if line == b"": # EoF received!
# Tell the monitor that we've finished triggers so it can update things
try:
resp = await self.comms_decoder.asend(msg)
except asyncio.IncompleteReadError:
if task := asyncio.current_task():
task.cancel("EOF - shutting down")
return
raise

resp = self.decoder.validate_json(line)
if not isinstance(resp, messages.TriggerStateSync):
raise RuntimeError(f"Expected to get a TriggerStateSync message, instead we got f{type(msg)}")
raise RuntimeError(f"Expected to get a TriggerStateSync message, instead we got {type(msg)}")
self.to_create.extend(resp.to_create)
self.to_cancel.extend(resp.to_cancel)
