Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions task_sdk/src/airflow/sdk/execution_time/comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,11 @@ class ErrorResponse(BaseModel):
]


class TaskState(BaseModel):
class NoResponseMessage:
"""A "marker" class/mixin that indicates this message type does not receive a response from the Supervisor."""


class TaskState(BaseModel, NoResponseMessage):
"""
Update a task's state.

Expand All @@ -196,13 +200,13 @@ class TaskState(BaseModel):
type: Literal["TaskState"] = "TaskState"


class DeferTask(TIDeferredStatePayload):
class DeferTask(TIDeferredStatePayload, NoResponseMessage):
"""Update a task instance state to deferred."""

type: Literal["DeferTask"] = "DeferTask"


class RescheduleTask(TIRescheduleStatePayload):
class RescheduleTask(TIRescheduleStatePayload, NoResponseMessage):
"""Update a task instance state to reschedule/up_for_reschedule."""

type: Literal["RescheduleTask"] = "RescheduleTask"
Expand All @@ -217,7 +221,7 @@ class GetXCom(BaseModel):
type: Literal["GetXCom"] = "GetXCom"


class SetXCom(BaseModel):
class SetXCom(BaseModel, NoResponseMessage):
key: str
value: Annotated[
# JsonValue can handle non JSON stringified dicts, lists and strings, which is better
Expand Down Expand Up @@ -265,7 +269,7 @@ class PutVariable(BaseModel):
type: Literal["PutVariable"] = "PutVariable"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should add for this one too?



class SetRenderedFields(BaseModel):
class SetRenderedFields(BaseModel, NoResponseMessage):
"""Payload for setting RTIF for a task instance."""

# We are using a BaseModel here compared to server using RootModel because we
Expand Down
14 changes: 5 additions & 9 deletions task_sdk/src/airflow/sdk/execution_time/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,7 @@ def _get_connection(conn_id: str) -> Connection:
from airflow.sdk.execution_time.comms import ErrorResponse, GetConnection
from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

SUPERVISOR_COMMS.send_request(log=log, msg=GetConnection(conn_id=conn_id))
msg = SUPERVISOR_COMMS.get_message()
msg = SUPERVISOR_COMMS.send_request(log=log, msg=GetConnection(conn_id=conn_id))
if isinstance(msg, ErrorResponse):
raise AirflowRuntimeError(msg)

Expand All @@ -100,8 +99,7 @@ def _get_variable(key: str, deserialize_json: bool) -> Variable:
from airflow.sdk.execution_time.comms import ErrorResponse, GetVariable
from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

SUPERVISOR_COMMS.send_request(log=log, msg=GetVariable(key=key))
msg = SUPERVISOR_COMMS.get_message()
msg = SUPERVISOR_COMMS.send_request(log=log, msg=GetVariable(key=key))
if isinstance(msg, ErrorResponse):
raise AirflowRuntimeError(msg)

Expand Down Expand Up @@ -265,13 +263,12 @@ def _get_asset_from_db(name: str | None = None, uri: str | None = None) -> Asset
from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

if name:
SUPERVISOR_COMMS.send_request(log=log, msg=GetAssetByName(name=name))
msg = SUPERVISOR_COMMS.send_request(log=log, msg=GetAssetByName(name=name))
elif uri:
SUPERVISOR_COMMS.send_request(log=log, msg=GetAssetByUri(uri=uri))
msg = SUPERVISOR_COMMS.send_request(log=log, msg=GetAssetByUri(uri=uri))
else:
raise ValueError("Either name or uri must be provided")

msg = SUPERVISOR_COMMS.get_message()
if isinstance(msg, ErrorResponse):
raise AirflowRuntimeError(msg)

Expand All @@ -289,8 +286,7 @@ def get_previous_dagrun_success(ti_id: UUID) -> PrevSuccessfulDagRunResponse:
)
from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

SUPERVISOR_COMMS.send_request(log=log, msg=GetPrevSuccessfulDagRun(ti_id=ti_id))
msg = SUPERVISOR_COMMS.get_message()
msg = SUPERVISOR_COMMS.send_request(log=log, msg=GetPrevSuccessfulDagRun(ti_id=ti_id))

if TYPE_CHECKING:
assert isinstance(msg, PrevSuccessfulDagRunResult)
Expand Down
24 changes: 19 additions & 5 deletions task_sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from datetime import datetime, timezone
from io import FileIO
from pathlib import Path
from typing import TYPE_CHECKING, Annotated, Any, Generic, TextIO, TypeVar
from typing import TYPE_CHECKING, Annotated, Any, Generic, TextIO, TypeVar, overload

import attrs
import lazy_object_proxy
Expand All @@ -39,6 +39,7 @@
from airflow.sdk.execution_time.comms import (
DeferTask,
GetXCom,
NoResponseMessage,
RescheduleTask,
SetRenderedFields,
SetXCom,
Expand Down Expand Up @@ -248,7 +249,7 @@ def xcom_pull(

xcoms = []
for t in task_ids:
SUPERVISOR_COMMS.send_request(
msg = SUPERVISOR_COMMS.send_request(
log=log,
msg=GetXCom(
key=key,
Expand All @@ -259,7 +260,6 @@ def xcom_pull(
),
)

msg = SUPERVISOR_COMMS.get_message()
if not isinstance(msg, XComResult):
raise TypeError(f"Expected XComResult, received: {type(msg)} {msg}")

Expand Down Expand Up @@ -365,11 +365,13 @@ class CommsDecoder(Generic[ReceiveMsgType, SendMsgType]):
# "sort of wrong default"
decoder: TypeAdapter[ReceiveMsgType] = attrs.field(factory=lambda: TypeAdapter(ToTask), repr=False)

def get_message(self) -> ReceiveMsgType:
def get_message(self) -> ReceiveMsgType | None:
"""
Get a message from the parent.

This will block until the message has been received.

Most of the time you should call ``send_request`` which will call this method
"""
line = self.input.readline()
try:
Expand All @@ -384,11 +386,23 @@ def get_message(self) -> ReceiveMsgType:
self.request_socket = os.fdopen(msg.requests_fd, "wb", buffering=0)
return msg

def send_request(self, log: Logger, msg: SendMsgType):
@overload
def send_request(self, log: Logger, msg: NoResponseMessage) -> None: ...

@overload
def send_request(self, log: Logger, msg: SendMsgType) -> ReceiveMsgType: ...

def send_request(self, log: Logger, msg: SendMsgType | NoResponseMessage) -> ReceiveMsgType | None:
"""Send a request to the parent and return the response message (which might be None)."""
if TYPE_CHECKING:
assert isinstance(msg, BaseModel)
encoded_msg = msg.model_dump_json().encode() + b"\n"

log.debug("Sending request", json=encoded_msg)
self.request_socket.write(encoded_msg)
if isinstance(msg, NoResponseMessage):
return self.get_message()
return None


# This global variable will be used by Connection/Variable/XCom classes, or other parts of the task's execution,
Expand Down
16 changes: 11 additions & 5 deletions task_sdk/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,8 +219,14 @@ def _make_context_dict(


@pytest.fixture
def mock_supervisor_comms():
with mock.patch(
"airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True
) as supervisor_comms:
yield supervisor_comms
def mock_supervisor_comms(monkeypatch):
from airflow.sdk.execution_time import task_runner

comms = mock.Mock(spec=task_runner.CommsDecoder)

def send_request(*args, **kwargs):
return comms.get_message()

comms.send_request.side_effect = send_request
monkeypatch.setattr(task_runner, "SUPERVISOR_COMMS", comms, raising=False)
return comms
6 changes: 3 additions & 3 deletions task_sdk/tests/execution_time/test_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,16 +438,16 @@ def execute(self, context):
startup()
run(ti, log=mock.MagicMock())
expected_calls = [
mock.call.send_request(
mock.call(
msg=SetRenderedFields(rendered_fields=expected_rendered_fields),
log=mock.ANY,
),
mock.call.send_request(
mock.call(
msg=TaskState(end_date=instant, state=TerminalTIState.SUCCESS),
log=mock.ANY,
),
]
mock_supervisor_comms.assert_has_calls(expected_calls)
mock_supervisor_comms.send_request.assert_has_calls(expected_calls)


@pytest.mark.parametrize(
Expand Down
Loading