Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions task_sdk/src/airflow/sdk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,15 @@
"Label",
"TaskGroup",
"dag",
"Connection",
"__version__",
]

__version__ = "1.0.0.dev1"

if TYPE_CHECKING:
from airflow.sdk.definitions.baseoperator import BaseOperator
from airflow.sdk.definitions.connection import Connection
from airflow.sdk.definitions.dag import DAG, dag
from airflow.sdk.definitions.edges import EdgeModifier, Label
from airflow.sdk.definitions.taskgroup import TaskGroup
Expand All @@ -43,6 +45,7 @@
"TaskGroup": ".definitions.taskgroup",
"EdgeModifier": ".definitions.edges",
"Label": ".definitions.edges",
"Connection": ".definitions.connection",
}


Expand Down
17 changes: 15 additions & 2 deletions task_sdk/src/airflow/sdk/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import sys
import uuid
from http import HTTPStatus
from typing import TYPE_CHECKING, Any, TypeVar

import httpx
Expand All @@ -43,6 +44,8 @@
VariableResponse,
XComResponse,
)
from airflow.sdk.exceptions import ErrorType
from airflow.sdk.execution_time.comms import ErrorResponse
from airflow.utils.net import get_hostname
from airflow.utils.platform import getuser

Expand Down Expand Up @@ -161,9 +164,19 @@ class ConnectionOperations:
def __init__(self, client: Client):
self.client = client

def get(self, conn_id: str) -> ConnectionResponse:
def get(self, conn_id: str) -> ConnectionResponse | ErrorResponse:
"""Get a connection from the API server."""
resp = self.client.get(f"connections/{conn_id}")
try:
resp = self.client.get(f"connections/{conn_id}")
except ServerResponseError as e:
if e.response.status_code == HTTPStatus.NOT_FOUND:
log.error(
"Connection not found",
conn_id=conn_id,
detail=e.detail,
status_code=e.response.status_code,
)
return ErrorResponse(error=ErrorType.CONNECTION_NOT_FOUND, detail={"conn_id": conn_id})
return ConnectionResponse.model_validate_json(resp.read())


Expand Down
52 changes: 52 additions & 0 deletions task_sdk/src/airflow/sdk/definitions/connection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import attrs


@attrs.define
class Connection:
"""
A connection to an external data source.

:param conn_id: The connection ID.
:param conn_type: The connection type.
:param description: The connection description.
:param host: The host.
:param login: The login.
:param password: The password.
:param schema: The schema.
:param port: The port number.
:param extra: Extra metadata. Non-standard data such as private/SSH keys can be saved here. JSON
encoded object.
"""

conn_id: str
conn_type: str
description: str | None = None
host: str | None = None
schema: str | None = None
login: str | None = None
password: str | None = None
port: int | None = None
extra: str | None = None

def get_uri(self): ...

def get_hook(self): ...
21 changes: 21 additions & 0 deletions task_sdk/src/airflow/sdk/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,24 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import enum
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from airflow.sdk.execution_time.comms import ErrorResponse


class AirflowRuntimeError(Exception):
def __init__(self, error: ErrorResponse):
self.error = error
super().__init__(f"{error.error.value}: {error.detail}")


class ErrorType(enum.Enum):
CONNECTION_NOT_FOUND = "CONNECTION_NOT_FOUND"
VARIABLE_NOT_FOUND = "VARIABLE_NOT_FOUND"
XCOM_NOT_FOUND = "XCOM_NOT_FOUND"
GENERIC_ERROR = "GENERIC_ERROR"
16 changes: 15 additions & 1 deletion task_sdk/src/airflow/sdk/execution_time/comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
VariableResponse,
XComResponse,
)
from airflow.sdk.exceptions import ErrorType


class StartupDetails(BaseModel):
Expand All @@ -85,13 +86,26 @@ class XComResult(XComResponse):
class ConnectionResult(ConnectionResponse):
type: Literal["ConnectionResult"] = "ConnectionResult"

@classmethod
def from_conn_response(cls, connection_response: ConnectionResponse) -> ConnectionResult:
# Exclude defaults to avoid sending unnecessary data
# Pass the type as ConnectionResult explicitly so we can then call model_dump_json with exclude_unset=True
# to avoid sending unset fields (which are defaults in our case).
return cls(**connection_response.model_dump(exclude_defaults=True), type="ConnectionResult")


class VariableResult(VariableResponse):
type: Literal["VariableResult"] = "VariableResult"


class ErrorResponse(BaseModel):
error: ErrorType = ErrorType.GENERIC_ERROR
detail: dict | None = None
type: Literal["ErrorResponse"] = "ErrorResponse"


ToTask = Annotated[
Union[StartupDetails, XComResult, ConnectionResult, VariableResult],
Union[StartupDetails, XComResult, ConnectionResult, VariableResult, ErrorResponse],
Field(discriminator="type"),
]

Expand Down
78 changes: 78 additions & 0 deletions task_sdk/src/airflow/sdk/execution_time/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from typing import TYPE_CHECKING, Any

import structlog

from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType

if TYPE_CHECKING:
from airflow.sdk.definitions.connection import Connection
from airflow.sdk.execution_time.comms import ConnectionResult


def _convert_connection_result_conn(conn_result: ConnectionResult):
from airflow.sdk.definitions.connection import Connection

# `by_alias=True` is used to convert the `schema` field to `schema_` in the Connection model
return Connection(**conn_result.model_dump(exclude={"type"}, by_alias=True))


def _get_connection(conn_id: str) -> Connection:
# TODO: This should probably be moved to a separate module like `airflow.sdk.execution_time.comms`
# or `airflow.sdk.execution_time.connection`
# A reason to not move it to `airflow.sdk.execution_time.comms` is that it
# will make that module depend on Task SDK, which is not ideal because we intend to
# keep Task SDK as a separate package than execution time mods.
from airflow.sdk.execution_time.comms import ErrorResponse, GetConnection
from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

log = structlog.get_logger(logger_name="task")
SUPERVISOR_COMMS.send_request(log=log, msg=GetConnection(conn_id=conn_id))
msg = SUPERVISOR_COMMS.get_message()
if isinstance(msg, ErrorResponse):
raise AirflowRuntimeError(msg)

if TYPE_CHECKING:
assert isinstance(msg, ConnectionResult)
return _convert_connection_result_conn(msg)


class ConnectionAccessor:
"""Wrapper to access Connection entries in template."""

def __getattr__(self, conn_id: str) -> Any:
return _get_connection(conn_id)

def __repr__(self) -> str:
return "<ConnectionAccessor (dynamic access)>"

def __eq__(self, other):
if not isinstance(other, ConnectionAccessor):
return False
# All instances of ConnectionAccessor are equal since it is a stateless dynamic accessor
return True

def get(self, conn_id: str, default_conn: Any = None) -> Any:
try:
return _get_connection(conn_id)
except AirflowRuntimeError as e:
if e.error.error == ErrorType.CONNECTION_NOT_FOUND:
return default_conn
raise
9 changes: 8 additions & 1 deletion task_sdk/src/airflow/sdk/execution_time/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,15 @@

from airflow.sdk.api.client import Client, ServerResponseError
from airflow.sdk.api.datamodels._generated import (
ConnectionResponse,
IntermediateTIState,
TaskInstance,
TerminalTIState,
)
from airflow.sdk.execution_time.comms import (
ConnectionResult,
DeferTask,
ErrorResponse,
GetConnection,
GetVariable,
GetXCom,
Expand Down Expand Up @@ -689,7 +692,11 @@ def _handle_request(self, msg, log):
self._task_end_time_monotonic = time.monotonic()
elif isinstance(msg, GetConnection):
conn = self.client.connections.get(msg.conn_id)
resp = conn.model_dump_json(exclude_unset=True).encode()
if isinstance(conn, ConnectionResponse):
conn_result = ConnectionResult.from_conn_response(conn)
resp = conn_result.model_dump_json(exclude_unset=True).encode()
elif isinstance(conn, ErrorResponse):
resp = conn.model_dump_json().encode()
elif isinstance(msg, GetVariable):
var = self.client.variables.get(msg.key)
resp = var.model_dump_json(exclude_unset=True).encode()
Expand Down
13 changes: 11 additions & 2 deletions task_sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
ToSupervisor,
ToTask,
)
from airflow.sdk.execution_time.context import ConnectionAccessor

if TYPE_CHECKING:
from structlog.typing import FilteringBoundLogger as Logger
Expand All @@ -53,6 +54,9 @@ class RuntimeTaskInstance(TaskInstance):
"""The Task Instance context from the API server, if any."""

def get_template_context(self):
# TODO: Move this to `airflow.sdk.execution_time.context`
# once we port the entire context logic from airflow/utils/context.py ?

# TODO: Assess if we need to it through airflow.utils.timezone.coerce_datetime()
context: dict[str, Any] = {
# From the Task Execution interface
Expand All @@ -63,6 +67,8 @@ def get_template_context(self):
"run_id": self.run_id,
"task": self.task,
"task_instance": self,
# TODO: Ensure that ti.log_url and such are available to use in context
# especially after removal of `conf` from Context.
"ti": self,
# "outlet_events": OutletEventAccessors(),
# "expanded_ti_count": expanded_ti_count,
Expand All @@ -73,14 +79,13 @@ def get_template_context(self):
# "prev_data_interval_end_success": get_prev_data_interval_end_success(),
# "prev_start_date_success": get_prev_start_date_success(),
# "prev_end_date_success": get_prev_end_date_success(),
# "task_instance_key_str": f"{task.dag_id}__{task.task_id}__{ds_nodash}",
# "test_mode": task_instance.test_mode,
# "triggering_asset_events": lazy_object_proxy.Proxy(get_triggering_events),
# "var": {
# "json": VariableAccessor(deserialize_json=True),
# "value": VariableAccessor(deserialize_json=False),
# },
# "conn": ConnectionAccessor(),
"conn": ConnectionAccessor(),
}
if self._ti_context_from_server:
dag_run = self._ti_context_from_server.dag_run
Expand Down Expand Up @@ -108,6 +113,10 @@ def get_template_context(self):
context.update(context_from_server)
return context

def xcom_pull(self, *args, **kwargs): ...

def xcom_push(self, *args, **kwargs): ...


def parse(what: StartupDetails) -> RuntimeTaskInstance:
# TODO: Task-SDK:
Expand Down
Loading