Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions task-sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,11 @@ def parse(what: StartupDetails, log: Logger) -> RuntimeTaskInstance:
)
bundle_instance.initialize()

# Put bundle root on sys.path if needed. This allows the dag bundle to add
# code in util modules to be shared between files within the same bundle.
if (bundle_root := os.fspath(bundle_instance.path)) not in sys.path:
sys.path.append(bundle_root)

dag_absolute_path = os.fspath(Path(bundle_instance.path, what.dag_rel_path))
bag = DagBag(
dag_folder=dag_absolute_path,
Expand Down
49 changes: 49 additions & 0 deletions task-sdk/tests/task_sdk/execution_time/test_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import functools
import json
import os
import textwrap
import uuid
from collections.abc import Iterable
from datetime import datetime, timedelta
Expand Down Expand Up @@ -274,6 +275,54 @@ def test_parse_not_found(test_dags_dir: Path, make_ti_context, dag_id, task_id,
log.error.assert_has_calls([expected_error])


def test_parse_module_in_bundle_root(tmp_path: Path, make_ti_context):
"""Check that the bundle path is added to sys.path, so Dags can import shared modules."""
tmp_path.joinpath("util.py").write_text("NAME = 'dag_name'")

dag1_path = tmp_path.joinpath("path_test.py")
dag1_code = """
from util import NAME
from airflow.sdk import DAG
from airflow.sdk.bases.operator import BaseOperator
with DAG(NAME):
BaseOperator(task_id="a")
"""
dag1_path.write_text(textwrap.dedent(dag1_code))

what = StartupDetails(
ti=TaskInstance(
id=uuid7(),
task_id="a",
dag_id="dag_name",
run_id="c",
try_number=1,
),
dag_rel_path="path_test.py",
bundle_info=BundleInfo(name="my-bundle", version=None),
requests_fd=0,
ti_context=make_ti_context(),
start_date=timezone.utcnow(),
)

with patch.dict(
os.environ,
{
"AIRFLOW__DAG_PROCESSOR__DAG_BUNDLE_CONFIG_LIST": json.dumps(
[
{
"name": "my-bundle",
"classpath": "airflow.dag_processing.bundles.local.LocalDagBundle",
"kwargs": {"path": str(tmp_path), "refresh_interval": 1},
}
]
),
},
):
ti = parse(what, mock.Mock())

assert ti.task.dag.dag_id == "dag_name"


def test_run_deferred_basic(time_machine, create_runtime_ti, mock_supervisor_comms):
"""Test that a task can transition to a deferred state."""
from airflow.providers.standard.sensors.date_time import DateTimeSensorAsync
Expand Down