Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions task_sdk/tests/definitions/test_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,3 +423,37 @@ def test_create_dag_while_active_context(self):
with DAG(dag_id="simple_dag"):
DAG(dag_id="dag2")
# No asserts needed, it just needs to not fail

def test_documentation_template_rendered(self):
"""Test that @dag uses function docs as doc_md for DAG object"""

@dag_decorator(schedule=None, default_args=self.DEFAULT_ARGS)
def noop_pipeline():
"""
{% if True %}
Regular DAG documentation
{% endif %}
"""

dag = noop_pipeline()
assert dag.dag_id == "noop_pipeline"
assert "Regular DAG documentation" in dag.doc_md

def test_resolve_documentation_template_file_not_rendered(self, tmp_path):
"""Test that @dag uses function docs as doc_md for DAG object"""

raw_content = """
{% if True %}
External Markdown DAG documentation
{% endif %}
"""

path = tmp_path / "testfile.md"
path.write_text(raw_content)

@dag_decorator("test-dag", schedule=None, start_date=DEFAULT_DATE, doc_md=str(path))
def markdown_docs(): ...

dag = markdown_docs()
assert dag.dag_id == "test-dag"
assert dag.doc_md == raw_content
12 changes: 12 additions & 0 deletions task_sdk/tests/execution_time/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,18 @@ def set_dag(what: StartupDetails, dag_id: str, task: BaseOperator) -> RuntimeTas
from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance, parse
from airflow.utils import timezone

if task.has_dag():
if what.ti_context.dag_run.conf:
task.dag.params = what.ti_context.dag_run.conf # type: ignore[assignment]
ti = RuntimeTaskInstance.model_construct(
**what.ti.model_dump(exclude_unset=True),
task=task,
_ti_context_from_server=what.ti_context,
max_tries=what.ti_context.max_tries,
)
spy_agency.spy_on(parse, call_fake=lambda _: ti)
return ti

dag = DAG(dag_id=dag_id, start_date=timezone.datetime(2024, 12, 3))
if what.ti_context.dag_run.conf:
dag.params = what.ti_context.dag_run.conf # type: ignore[assignment]
Expand Down
122 changes: 119 additions & 3 deletions task_sdk/tests/execution_time/test_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import json
import os
import uuid
from datetime import timedelta
from datetime import datetime, timedelta
from pathlib import Path
from socket import socketpair
from unittest import mock
Expand All @@ -29,6 +29,7 @@
import pytest
from uuid6 import uuid7

from airflow.decorators import task as task_decorator
from airflow.exceptions import (
AirflowException,
AirflowFailException,
Expand All @@ -37,9 +38,10 @@
AirflowTaskTerminated,
)
from airflow.providers.standard.operators.python import PythonOperator
from airflow.sdk import DAG, BaseOperator, Connection, get_current_context
from airflow.sdk import DAG, BaseOperator, Connection, dag as dag_decorator, get_current_context
from airflow.sdk.api.datamodels._generated import AssetProfile, TaskInstance, TerminalTIState
from airflow.sdk.definitions.asset import Asset, AssetAlias
from airflow.sdk.definitions.param import DagParam
from airflow.sdk.definitions.variable import Variable
from airflow.sdk.execution_time.comms import (
BundleInfo,
Expand Down Expand Up @@ -1198,6 +1200,15 @@ def execute(self, context):


class TestDagParamRuntime:
DEFAULT_ARGS = {
"owner": "test",
"depends_on_past": True,
"start_date": datetime.now(tz=timezone.utc),
"retries": 1,
"retry_delay": timedelta(minutes=1),
}
VALUE = 42

def test_dag_param_resolves_from_task(self, create_runtime_ti, mock_supervisor_comms, time_machine):
"""Test dagparam resolves on operator execution"""
instant = timezone.datetime(2024, 12, 3, 10, 0)
Expand Down Expand Up @@ -1252,7 +1263,7 @@ def execute(self, context):
)

def test_dag_param_dag_default(self, create_runtime_ti, mock_supervisor_comms, time_machine):
""" "Test dag param is retrieved from default config"""
"""Test that dag param is correctly resolved by operator"""
instant = timezone.datetime(2024, 12, 3, 10, 0)
time_machine.move_to(instant, tick=False)

Expand All @@ -1277,3 +1288,108 @@ def execute(self, context):
),
log=mock.ANY,
)

def test_dag_param_resolves(
self, create_runtime_ti, mock_supervisor_comms, time_machine, make_ti_context
):
"""Test that dag param is correctly resolved by operator"""

instant = timezone.datetime(2024, 12, 3, 10, 0)
time_machine.move_to(instant, tick=False)

@dag_decorator(schedule=None, start_date=timezone.datetime(2024, 12, 3))
def dag_with_dag_params(value="NOTSET"):
@task_decorator
def dummy_task(val):
return val

class CustomOperator(BaseOperator):
def execute(self, context):
assert self.dag.params["value"] == "NOTSET"

_ = dummy_task(value)
custom_task = CustomOperator(task_id="task_with_dag_params")
self.operator = custom_task

dag_with_dag_params()

runtime_ti = create_runtime_ti(task=self.operator, dag_id="dag_with_dag_params")

run(runtime_ti, log=mock.MagicMock())

mock_supervisor_comms.send_request.assert_called_once_with(
msg=SucceedTask(
state=TerminalTIState.SUCCESS, end_date=instant, task_outlets=[], outlet_events=[]
),
log=mock.ANY,
)

def test_dag_param_dagrun_parameterized(
self, create_runtime_ti, mock_supervisor_comms, time_machine, make_ti_context
):
"""Test that dag param is correctly overwritten when set in dag run"""

instant = timezone.datetime(2024, 12, 3, 10, 0)
time_machine.move_to(instant, tick=False)

@dag_decorator(schedule=None, start_date=timezone.datetime(2024, 12, 3))
def dag_with_dag_params(value=self.VALUE):
@task_decorator
def dummy_task(val):
return val

assert isinstance(value, DagParam)

class CustomOperator(BaseOperator):
def execute(self, context):
assert self.dag.params["value"] == "new_value"

_ = dummy_task(value)
custom_task = CustomOperator(task_id="task_with_dag_params")
self.operator = custom_task

dag_with_dag_params()

runtime_ti = create_runtime_ti(
task=self.operator, dag_id="dag_with_dag_params", conf={"value": "new_value"}
)

run(runtime_ti, log=mock.MagicMock())

mock_supervisor_comms.send_request.assert_called_once_with(
msg=SucceedTask(
state=TerminalTIState.SUCCESS, end_date=instant, task_outlets=[], outlet_events=[]
),
log=mock.ANY,
)

@pytest.mark.parametrize("value", [VALUE, 0])
def test_set_params_for_dag(
self, create_runtime_ti, mock_supervisor_comms, time_machine, make_ti_context, value
):
"""Test that dag param is correctly set when using dag decorator"""

instant = timezone.datetime(2024, 12, 3, 10, 0)
time_machine.move_to(instant, tick=False)

@dag_decorator(schedule=None, start_date=timezone.datetime(2024, 12, 3))
def dag_with_param(value=value):
@task_decorator
def return_num(num):
return num

xcom_arg = return_num(value)
self.operator = xcom_arg.operator

dag_with_param()

runtime_ti = create_runtime_ti(task=self.operator, dag_id="dag_with_param", conf={"value": value})

run(runtime_ti, log=mock.MagicMock())

mock_supervisor_comms.send_request.assert_any_call(
msg=SucceedTask(
state=TerminalTIState.SUCCESS, end_date=instant, task_outlets=[], outlet_events=[]
),
log=mock.ANY,
)
132 changes: 1 addition & 131 deletions tests/models/test_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@
DagOwnerAttributes,
DagTag,
ExecutorLoader,
dag as dag_decorator,
get_asset_triggered_next_run_info,
)
from airflow.models.dag_version import DagVersion
Expand All @@ -73,7 +72,7 @@
from airflow.sdk.definitions._internal.contextmanager import TaskGroupContext
from airflow.sdk.definitions._internal.templater import NativeEnvironment, SandboxedEnvironment
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAll, AssetAny
from airflow.sdk.definitions.param import DagParam, Param
from airflow.sdk.definitions.param import Param
from airflow.security import permissions
from airflow.timetables.base import DagRunInfo, DataInterval, TimeRestriction, Timetable
from airflow.timetables.simple import (
Expand Down Expand Up @@ -2519,135 +2518,6 @@ def test_count_number_queries(self, tasks_count):
)


class TestDagDecorator:
DEFAULT_ARGS = {
"owner": "test",
"depends_on_past": True,
"start_date": timezone.utcnow(),
"retries": 1,
"retry_delay": timedelta(minutes=1),
}
DEFAULT_DATE = timezone.datetime(2016, 1, 1)
VALUE = 42

def setup_method(self):
self.operator = None

def teardown_method(self):
clear_db_runs()

def test_documentation_template_rendered(self):
"""Test that @dag uses function docs as doc_md for DAG object"""

@dag_decorator(schedule=None, default_args=self.DEFAULT_ARGS)
def noop_pipeline():
"""
{% if True %}
Regular DAG documentation
{% endif %}
"""

dag = noop_pipeline()
assert dag.dag_id == "noop_pipeline"
assert "Regular DAG documentation" in dag.doc_md

def test_resolve_documentation_template_file_not_rendered(self, tmp_path):
"""Test that @dag uses function docs as doc_md for DAG object"""

raw_content = """
{% if True %}
External Markdown DAG documentation
{% endif %}
"""

path = tmp_path / "testfile.md"
path.write_text(raw_content)

@dag_decorator("test-dag", schedule=None, start_date=DEFAULT_DATE, doc_md=str(path))
def markdown_docs(): ...

dag = markdown_docs()
assert dag.dag_id == "test-dag"
assert dag.doc_md == raw_content

def test_dag_param_resolves(self):
"""Test that dag param is correctly resolved by operator"""

@dag_decorator(schedule=None, default_args=self.DEFAULT_ARGS)
def xcom_pass_to_op(value=self.VALUE):
@task_decorator
def return_num(num):
return num

xcom_arg = return_num(value)
self.operator = xcom_arg.operator

dag = xcom_pass_to_op()

dr = dag.create_dagrun(
run_id="test",
run_type=DagRunType.MANUAL,
start_date=timezone.utcnow(),
logical_date=self.DEFAULT_DATE,
data_interval=(self.DEFAULT_DATE, self.DEFAULT_DATE),
run_after=self.DEFAULT_DATE,
state=State.RUNNING,
triggered_by=DagRunTriggeredByType.TEST,
)

self.operator.run(start_date=self.DEFAULT_DATE, end_date=self.DEFAULT_DATE)
ti = dr.get_task_instances()[0]
assert ti.xcom_pull() == self.VALUE

def test_dag_param_dagrun_parameterized(self):
"""Test that dag param is correctly overwritten when set in dag run"""

@dag_decorator(schedule=None, default_args=self.DEFAULT_ARGS)
def xcom_pass_to_op(value=self.VALUE):
@task_decorator
def return_num(num):
return num

assert isinstance(value, DagParam)

xcom_arg = return_num(value)
self.operator = xcom_arg.operator

dag = xcom_pass_to_op()
new_value = 52
dr = dag.create_dagrun(
run_id="test",
run_type=DagRunType.MANUAL,
start_date=timezone.utcnow(),
logical_date=self.DEFAULT_DATE,
data_interval=(self.DEFAULT_DATE, self.DEFAULT_DATE),
run_after=self.DEFAULT_DATE,
state=State.RUNNING,
conf={"value": new_value},
triggered_by=DagRunTriggeredByType.TEST,
)

self.operator.run(start_date=self.DEFAULT_DATE, end_date=self.DEFAULT_DATE)
ti = dr.get_task_instances()[0]
assert ti.xcom_pull() == new_value

@pytest.mark.parametrize("value", [VALUE, 0])
def test_set_params_for_dag(self, value):
"""Test that dag param is correctly set when using dag decorator"""

@dag_decorator(schedule=None, default_args=self.DEFAULT_ARGS)
def xcom_pass_to_op(value=value):
@task_decorator
def return_num(num):
return num

xcom_arg = return_num(value)
self.operator = xcom_arg.operator

dag = xcom_pass_to_op()
assert dag.params["value"] == value


@pytest.mark.parametrize(
"run_id",
["test-run-id"],
Expand Down