Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions airflow-core/src/airflow/dag_processing/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from uuid6 import uuid7

import airflow.models
from airflow.api_fastapi.execution_api.app import InProcessExecutionAPI
from airflow.configuration import conf
from airflow.dag_processing.bundles.manager import DagBundlesManager
from airflow.dag_processing.collection import update_dag_parsing_results_in_db
Expand Down Expand Up @@ -80,6 +81,7 @@

from airflow.callbacks.callback_requests import CallbackRequest
from airflow.dag_processing.bundles.base import BaseDagBundle
from airflow.sdk.api.client import Client


class DagParsingStat(NamedTuple):
Expand Down Expand Up @@ -213,6 +215,9 @@ class DagFileProcessorManager(LoggingMixin):
_force_refresh_bundles: set[str] = attrs.field(factory=set, init=False)
"""List of bundles that need to be force refreshed in the next loop"""

_api_server: InProcessExecutionAPI = attrs.field(init=False, factory=InProcessExecutionAPI)
"""API server to interact with Metadata DB"""

def register_exit_signals(self):
"""Register signals that stop child processes."""
signal.signal(signal.SIGINT, self._exit_gracefully)
Expand Down Expand Up @@ -867,6 +872,15 @@ def _get_logger_for_dag_file(self, dag_file: DagFileInfo):
underlying_logger, processors=processors, logger_name="processor"
).bind(), logger_filehandle

@functools.cached_property
def client(self) -> Client:
    """Return a lazily-built, cached in-process API client.

    The client talks to the manager's ``InProcessExecutionAPI`` via its
    transport, so no real network connection is made (``dry_run=True``,
    empty token).
    """
    from airflow.sdk.api.client import Client

    in_proc = Client(base_url=None, token="", dry_run=True, transport=self._api_server.transport)
    # The base_url property setter also accepts a plain string (`URLType = URL | str`),
    # even though mypy only sees the narrower declared attribute type.
    in_proc.base_url = "http://in-process.invalid./"  # type: ignore[assignment]
    return in_proc

def _create_process(self, dag_file: DagFileInfo) -> DagFileProcessorProcess:
id = uuid7()

Expand All @@ -881,6 +895,7 @@ def _create_process(self, dag_file: DagFileInfo) -> DagFileProcessorProcess:
selector=self.selector,
logger=logger,
logger_filehandle=logger_filehandle,
client=self.client,
)

def _start_new_processes(self):
Expand Down
16 changes: 5 additions & 11 deletions airflow-core/src/airflow/dag_processing/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
# under the License.
from __future__ import annotations

import functools
import os
import sys
import traceback
Expand Down Expand Up @@ -239,6 +238,9 @@ class DagFileProcessorProcess(WatchedSubprocess):
parsing_result: DagFileParsingResult | None = None
decoder: ClassVar[TypeAdapter[ToManager]] = TypeAdapter[ToManager](ToManager)

client: Client
"""The HTTP client to use for communication with the API server."""

@classmethod
def start( # type: ignore[override]
cls,
Expand All @@ -247,9 +249,10 @@ def start( # type: ignore[override]
bundle_path: Path,
callbacks: list[CallbackRequest],
target: Callable[[], None] = _parse_file_entrypoint,
client: Client,
**kwargs,
) -> Self:
proc: Self = super().start(target=target, **kwargs)
proc: Self = super().start(target=target, client=client, **kwargs)
proc._on_child_started(callbacks, path, bundle_path)
return proc

Expand All @@ -267,15 +270,6 @@ def _on_child_started(
)
self.send_msg(msg)

@functools.cached_property
def client(self) -> Client:
    """Return a cached API ``Client`` routed through the in-process API server.

    Built once per process instance (``cached_property``). ``dry_run=True`` and
    an empty token are used because requests never leave the process: they go
    over ``in_process_api_server().transport``.
    NOTE(review): ``in_process_api_server`` is defined elsewhere in this module;
    presumably it returns a process-wide singleton — confirm before relying on
    shared state across instances.
    """
    # Local import keeps the SDK dependency off this module's import path.
    from airflow.sdk.api.client import Client

    client = Client(base_url=None, token="", dry_run=True, transport=in_process_api_server().transport)
    # Mypy is wrong -- the setter accepts a string on the property setter! `URLType = URL | str`
    client.base_url = "http://in-process.invalid./"  # type: ignore[assignment]
    return client

def _handle_request(self, msg: ToManager, log: FilteringBoundLogger) -> None: # type: ignore[override]
from airflow.sdk.api.datamodels._generated import ConnectionResponse, VariableResponse

Expand Down
3 changes: 3 additions & 0 deletions airflow-core/tests/unit/dag_processing/test_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ def mock_processor(self, start_time: float | None = None) -> tuple[DagFileProces
stdin=write_end,
requests_fd=123,
logger_filehandle=logger_filehandle,
client=MagicMock(),
)
if start_time:
ret.start_time = start_time
Expand Down Expand Up @@ -899,6 +900,7 @@ def test_callback_queue(self, mock_get_logger, configure_testing_dag_bundle):
selector=mock.ANY,
logger=mock_logger,
logger_filehandle=mock_filehandle,
client=mock.ANY,
),
mock.call(
id=mock.ANY,
Expand All @@ -908,6 +910,7 @@ def test_callback_queue(self, mock_get_logger, configure_testing_dag_bundle):
selector=mock.ANY,
logger=mock_logger,
logger_filehandle=mock_filehandle,
client=mock.ANY,
),
]
# And removed from the queue
Expand Down
70 changes: 57 additions & 13 deletions airflow-core/tests/unit/dag_processing/test_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import structlog
from pydantic import TypeAdapter

from airflow.api_fastapi.execution_api.app import InProcessExecutionAPI
from airflow.callbacks.callback_requests import CallbackRequest, DagCallbackRequest, TaskCallbackRequest
from airflow.configuration import conf
from airflow.dag_processing.processor import (
Expand All @@ -40,6 +41,7 @@
from airflow.models import DagBag, TaskInstance
from airflow.models.baseoperator import BaseOperator
from airflow.models.serialized_dag import SerializedDagModel
from airflow.sdk.api.client import Client
from airflow.sdk.execution_time.task_runner import CommsDecoder
from airflow.utils import timezone
from airflow.utils.session import create_session
Expand Down Expand Up @@ -67,6 +69,15 @@ def disable_load_example():
yield


@pytest.fixture
def inprocess_client():
    """Yieldless fixture: build a ``Client`` wired to a fresh ``InProcessExecutionAPI``.

    Requests are served in-process through the API's transport; ``dry_run=True``
    and an empty token mean no real network access or auth is involved.
    """
    server = InProcessExecutionAPI()
    in_proc = Client(base_url=None, token="", dry_run=True, transport=server.transport)
    # The property setter accepts a plain string (`URLType = URL | str`); mypy's
    # declared attribute type is narrower, hence the ignore.
    in_proc.base_url = "http://in-process.invalid/"  # type: ignore[assignment]
    return in_proc


@pytest.mark.usefixtures("disable_load_example")
class TestDagFileProcessor:
def _process_file(
Expand Down Expand Up @@ -130,7 +141,7 @@ def fake_collect_dags(dagbag: DagBag, *args, **kwargs):
assert "a.py" in resp.import_errors

def test_top_level_variable_access(
self, spy_agency: SpyAgency, tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch
self, spy_agency: SpyAgency, tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch, inprocess_client
):
logger_filehandle = MagicMock()

Expand All @@ -144,7 +155,12 @@ def dag_in_a_fn():

monkeypatch.setenv("AIRFLOW_VAR_MYVAR", "abc")
proc = DagFileProcessorProcess.start(
id=1, path=path, bundle_path=tmp_path, callbacks=[], logger_filehandle=logger_filehandle
id=1,
path=path,
bundle_path=tmp_path,
callbacks=[],
logger_filehandle=logger_filehandle,
client=inprocess_client,
)

while not proc.is_ready:
Expand All @@ -156,7 +172,7 @@ def dag_in_a_fn():
assert result.serialized_dags[0].dag_id == "test_abc"

def test_top_level_variable_access_not_found(
self, spy_agency: SpyAgency, tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch
self, spy_agency: SpyAgency, tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch, inprocess_client
):
logger_filehandle = MagicMock()

Expand All @@ -168,7 +184,12 @@ def dag_in_a_fn():

path = write_dag_in_a_fn_to_file(dag_in_a_fn, tmp_path)
proc = DagFileProcessorProcess.start(
id=1, path=path, bundle_path=tmp_path, callbacks=[], logger_filehandle=logger_filehandle
id=1,
path=path,
bundle_path=tmp_path,
callbacks=[],
logger_filehandle=logger_filehandle,
client=inprocess_client,
)

while not proc.is_ready:
Expand All @@ -180,7 +201,7 @@ def dag_in_a_fn():
if result.import_errors:
assert "VARIABLE_NOT_FOUND" in next(iter(result.import_errors.values()))

def test_top_level_variable_set(self, tmp_path: pathlib.Path):
def test_top_level_variable_set(self, tmp_path: pathlib.Path, inprocess_client):
from airflow.models.variable import Variable as VariableORM

logger_filehandle = MagicMock()
Expand All @@ -194,7 +215,12 @@ def dag_in_a_fn():

path = write_dag_in_a_fn_to_file(dag_in_a_fn, tmp_path)
proc = DagFileProcessorProcess.start(
id=1, path=path, bundle_path=tmp_path, callbacks=[], logger_filehandle=logger_filehandle
id=1,
path=path,
bundle_path=tmp_path,
callbacks=[],
logger_filehandle=logger_filehandle,
client=inprocess_client,
)

while not proc.is_ready:
Expand All @@ -210,7 +236,7 @@ def dag_in_a_fn():
assert len(all_vars) == 1
assert all_vars[0].key == "mykey"

def test_top_level_variable_delete(self, tmp_path: pathlib.Path):
def test_top_level_variable_delete(self, tmp_path: pathlib.Path, inprocess_client):
from airflow.models.variable import Variable as VariableORM

logger_filehandle = MagicMock()
Expand All @@ -230,7 +256,12 @@ def dag_in_a_fn():

path = write_dag_in_a_fn_to_file(dag_in_a_fn, tmp_path)
proc = DagFileProcessorProcess.start(
id=1, path=path, bundle_path=tmp_path, callbacks=[], logger_filehandle=logger_filehandle
id=1,
path=path,
bundle_path=tmp_path,
callbacks=[],
logger_filehandle=logger_filehandle,
client=inprocess_client,
)

while not proc.is_ready:
Expand All @@ -245,7 +276,9 @@ def dag_in_a_fn():
all_vars = session.query(VariableORM).all()
assert len(all_vars) == 0

def test_top_level_connection_access(self, tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch):
def test_top_level_connection_access(
self, tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch, inprocess_client
):
logger_filehandle = MagicMock()

def dag_in_a_fn():
Expand All @@ -259,7 +292,12 @@ def dag_in_a_fn():

monkeypatch.setenv("AIRFLOW_CONN_MY_CONN", '{"conn_type": "aws"}')
proc = DagFileProcessorProcess.start(
id=1, path=path, bundle_path=tmp_path, callbacks=[], logger_filehandle=logger_filehandle
id=1,
path=path,
bundle_path=tmp_path,
callbacks=[],
logger_filehandle=logger_filehandle,
client=inprocess_client,
)

while not proc.is_ready:
Expand All @@ -270,7 +308,7 @@ def dag_in_a_fn():
assert result.import_errors == {}
assert result.serialized_dags[0].dag_id == "test_my_conn"

def test_top_level_connection_access_not_found(self, tmp_path: pathlib.Path):
def test_top_level_connection_access_not_found(self, tmp_path: pathlib.Path, inprocess_client):
logger_filehandle = MagicMock()

def dag_in_a_fn():
Expand All @@ -282,7 +320,12 @@ def dag_in_a_fn():

path = write_dag_in_a_fn_to_file(dag_in_a_fn, tmp_path)
proc = DagFileProcessorProcess.start(
id=1, path=path, bundle_path=tmp_path, callbacks=[], logger_filehandle=logger_filehandle
id=1,
path=path,
bundle_path=tmp_path,
callbacks=[],
logger_filehandle=logger_filehandle,
client=inprocess_client,
)

while not proc.is_ready:
Expand All @@ -294,7 +337,7 @@ def dag_in_a_fn():
if result.import_errors:
assert "CONNECTION_NOT_FOUND" in next(iter(result.import_errors.values()))

def test_import_module_in_bundle_root(self, tmp_path: pathlib.Path):
def test_import_module_in_bundle_root(self, tmp_path: pathlib.Path, inprocess_client):
tmp_path.joinpath("util.py").write_text("NAME = 'dag_name'")

dag1_path = tmp_path.joinpath("dag1.py")
Expand All @@ -314,6 +357,7 @@ def test_import_module_in_bundle_root(self, tmp_path: pathlib.Path):
bundle_path=tmp_path,
callbacks=[],
logger_filehandle=MagicMock(),
client=inprocess_client,
)
while not proc.is_ready:
proc._service_subprocess(0.1)
Expand Down
Loading