Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 45 additions & 3 deletions task_sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

import os
import sys
from collections.abc import Iterable
from collections.abc import Iterable, Mapping
from datetime import datetime, timezone
from io import FileIO
from typing import TYPE_CHECKING, Annotated, Any, Generic, TextIO, TypeVar
Expand Down Expand Up @@ -197,7 +197,11 @@ def xcom_pull(

value = msg.value
if value is not None:
return value
from airflow.models.xcom import XCom

# TODO: Move XCom serialization & deserialization to Task SDK
# https://github.com/apache/airflow/issues/45231
return XCom.deserialize_value(value)
return default

def xcom_push(self, key: str, value: Any):
Expand All @@ -207,6 +211,12 @@ def xcom_push(self, key: str, value: Any):
:param key: Key to store the value under.
:param value: Value to store. Only be JSON-serializable may be used otherwise.
"""
from airflow.models.xcom import XCom

# TODO: Move XCom serialization & deserialization to Task SDK
# https://github.com/apache/airflow/issues/45231
value = XCom.serialize_value(value)

log = structlog.get_logger(logger_name="task")
SUPERVISOR_COMMS.send_request(
log=log,
Expand Down Expand Up @@ -381,7 +391,9 @@ def run(ti: RuntimeTaskInstance, log: Logger):
# - Update RTIF
# - Pre Execute
# etc
ti.task.execute(context) # type: ignore[attr-defined]
result = ti.task.execute(context) # type: ignore[attr-defined]
_push_xcom_if_needed(result, ti)

msg = TaskState(state=TerminalTIState.SUCCESS, end_date=datetime.now(tz=timezone.utc))
except TaskDeferred as defer:
classpath, trigger_kwargs = defer.trigger.serialize()
Expand Down Expand Up @@ -436,6 +448,36 @@ def run(ti: RuntimeTaskInstance, log: Logger):
SUPERVISOR_COMMS.send_request(msg=msg, log=log)


def _push_xcom_if_needed(result: Any, ti: RuntimeTaskInstance):
"""Push XCom values when task has ``do_xcom_push`` set to ``True`` and the task returns a result."""
if ti.task.do_xcom_push:
xcom_value = result
else:
xcom_value = None

# If the task returns a result, push an XCom containing it.
if xcom_value is None:
return

# If the task has multiple outputs, push each output as a separate XCom.
if ti.task.multiple_outputs:
if not isinstance(xcom_value, Mapping):
raise TypeError(
f"Returned output was type {type(xcom_value)} expected dictionary for multiple_outputs"
)
for key in xcom_value.keys():
if not isinstance(key, str):
raise TypeError(
"Returned dictionary keys must be strings when using "
f"multiple_outputs, found {key} ({type(key)}) instead"
)
for k, v in result.items():
ti.xcom_push(k, v)

# TODO: Use constant for XCom return key & use serialize_value from Task SDK
ti.xcom_push("return_value", result)


def finalize(log: Logger): ...


Expand Down
9 changes: 9 additions & 0 deletions task_sdk/tests/execution_time/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from __future__ import annotations

import sys
from unittest import mock

import pytest

Expand All @@ -31,3 +32,11 @@ def disable_capturing():
sys.stderr = sys.__stderr__
yield
sys.stdin, sys.stdout, sys.stderr = old_in, old_out, old_err


@pytest.fixture
def mock_supervisor_comms():
with mock.patch(
"airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True
) as supervisor_comms:
yield supervisor_comms
39 changes: 14 additions & 25 deletions task_sdk/tests/execution_time/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@

from __future__ import annotations

from unittest import mock

from airflow.sdk.definitions.connection import Connection
from airflow.sdk.exceptions import ErrorType
from airflow.sdk.execution_time.comms import ConnectionResult, ErrorResponse
Expand Down Expand Up @@ -51,7 +49,7 @@ def test_convert_connection_result_conn():


class TestConnectionAccessor:
def test_getattr_connection(self):
def test_getattr_connection(self, mock_supervisor_comms):
"""
Test that the connection is fetched when accessed via __getattr__.

Expand All @@ -62,42 +60,33 @@ def test_getattr_connection(self):
# Conn from the supervisor / API Server
conn_result = ConnectionResult(conn_id="mysql_conn", conn_type="mysql", host="mysql", port=3306)

with mock.patch(
"airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True
) as mock_supervisor_comms:
mock_supervisor_comms.get_message.return_value = conn_result
mock_supervisor_comms.get_message.return_value = conn_result

# Fetch the connection; triggers __getattr__
conn = accessor.mysql_conn
# Fetch the connection; triggers __getattr__
conn = accessor.mysql_conn

expected_conn = Connection(conn_id="mysql_conn", conn_type="mysql", host="mysql", port=3306)
assert conn == expected_conn
expected_conn = Connection(conn_id="mysql_conn", conn_type="mysql", host="mysql", port=3306)
assert conn == expected_conn

def test_get_method_valid_connection(self):
def test_get_method_valid_connection(self, mock_supervisor_comms):
"""Test that the get method returns the requested connection using `conn.get`."""
accessor = ConnectionAccessor()
conn_result = ConnectionResult(conn_id="mysql_conn", conn_type="mysql", host="mysql", port=3306)

with mock.patch(
"airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True
) as mock_supervisor_comms:
mock_supervisor_comms.get_message.return_value = conn_result
mock_supervisor_comms.get_message.return_value = conn_result

conn = accessor.get("mysql_conn")
assert conn == Connection(conn_id="mysql_conn", conn_type="mysql", host="mysql", port=3306)
conn = accessor.get("mysql_conn")
assert conn == Connection(conn_id="mysql_conn", conn_type="mysql", host="mysql", port=3306)

def test_get_method_with_default(self):
def test_get_method_with_default(self, mock_supervisor_comms):
"""Test that the get method returns the default connection when the requested connection is not found."""
accessor = ConnectionAccessor()
default_conn = {"conn_id": "default_conn", "conn_type": "sqlite"}
error_response = ErrorResponse(
error=ErrorType.CONNECTION_NOT_FOUND, detail={"conn_id": "nonexistent_conn"}
)

with mock.patch(
"airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True
) as mock_supervisor_comms:
mock_supervisor_comms.get_message.return_value = error_response
mock_supervisor_comms.get_message.return_value = error_response

conn = accessor.get("nonexistent_conn", default_conn=default_conn)
assert conn == default_conn
conn = accessor.get("nonexistent_conn", default_conn=default_conn)
assert conn == default_conn
Loading