Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@

from __future__ import annotations

# [START hitl_tutorial]
import datetime
from typing import TYPE_CHECKING

import pendulum

Expand All @@ -29,6 +29,43 @@
HITLOperator,
)
from airflow.sdk import DAG, Param, task
from airflow.sdk.bases.notifier import BaseNotifier

if TYPE_CHECKING:
from airflow.sdk.definitions.context import Context

# [START hitl_tutorial]


class LocalLogNotifier(BaseNotifier):
"""Simple notifier to demonstrate HITL notification without setup any connection."""

template_fields = ("message",)

def __init__(self, message: str) -> None:
self.message = message

def notify(self, context: Context) -> None:
self.log.info(self.message)


# [START htil_notifer]
hitl_request_callback = LocalLogNotifier(
message="""
[HITL]
Subject: {{ task.subject }}
Body: {{ task.body }}
Options: {{ task.options }}
Is Multiple Option: {{ task.multiple }}
Default Options: {{ task.defaults }}
Params: {{ task.params }}
"""
)
hitl_success_callback = LocalLogNotifier(
message="{% set task_id = task.task_id -%}{{ ti.xcom_pull(task_ids=task_id) }}"
)
hitl_failure_callback = LocalLogNotifier(message="Request to response to '{{ task.subject }}' failed")
# [END htil_notifer]

with DAG(
dag_id="example_hitl_operator",
Expand All @@ -41,6 +78,9 @@
task_id="wait_for_input",
subject="Please provide required information: ",
params={"information": Param("", type="string")},
notifiers=[hitl_request_callback],
on_success_callback=hitl_success_callback,
on_failure_callback=hitl_failure_callback,
)
# [END howto_hitl_entry_operator]

Expand All @@ -49,19 +89,52 @@
task_id="wait_for_option",
subject="Please choose one option to proceed: ",
options=["option 1", "option 2", "option 3"],
notifiers=[hitl_request_callback],
on_success_callback=hitl_success_callback,
on_failure_callback=hitl_failure_callback,
)
# [END howto_hitl_operator]

# [START howto_hitl_operator_muliple]
wait_for_multiple_options = HITLOperator(
task_id="wait_for_multiple_options",
subject="Please choose option to proceed: ",
options=["option 4", "option 5", "option 6"],
multiple=True,
notifiers=[hitl_request_callback],
on_success_callback=hitl_success_callback,
on_failure_callback=hitl_failure_callback,
)
# [END howto_hitl_operator_muliple]

# [START howto_hitl_operator_timeout]
wait_for_default_option = HITLOperator(
task_id="wait_for_default_option",
subject="Please choose option to proceed: ",
options=["option 7", "option 8", "option 9"],
defaults=["option 7"],
execution_timeout=datetime.timedelta(seconds=1),
notifiers=[hitl_request_callback],
on_success_callback=hitl_success_callback,
on_failure_callback=hitl_failure_callback,
)
# [END howto_hitl_operator_timeout]

# [START howto_hitl_approval_operator]
valid_input_and_options = ApprovalOperator(
task_id="valid_input_and_options",
subject="Are the following input and options valid?",
body="""
Input: {{ task_instance.xcom_pull(task_ids='wait_for_input', key='return_value')["params_input"]["information"] }}
Option: {{ task_instance.xcom_pull(task_ids='wait_for_option', key='return_value')["chosen_options"] }}
Input: {{ ti.xcom_pull(task_ids='wait_for_input')["params_input"]["information"] }}
Option: {{ ti.xcom_pull(task_ids='wait_for_option')["chosen_options"] }}
Multiple Options: {{ ti.xcom_pull(task_ids='wait_for_option')["chosen_options"] }}
Timeout Option: {{ ti.xcom_pull(task_ids='wait_for_option')["chosen_options"] }}
""",
defaults="Reject",
execution_timeout=datetime.timedelta(minutes=1),
notifiers=[hitl_request_callback],
on_success_callback=hitl_success_callback,
on_failure_callback=hitl_failure_callback,
)
# [END howto_hitl_approval_operator]

Expand All @@ -70,6 +143,9 @@
task_id="choose_a_branch_to_run",
subject="You're now allowed to proceeded. Please choose one task to run: ",
options=["task_1", "task_2", "task_3"],
notifiers=[hitl_request_callback],
on_success_callback=hitl_success_callback,
on_failure_callback=hitl_failure_callback,
)
# [END howto_hitl_branch_operator]

Expand All @@ -84,7 +160,7 @@ def task_2(): ...
def task_3(): ...

(
[wait_for_input, wait_for_option]
[wait_for_input, wait_for_option, wait_for_default_option, wait_for_multiple_options]
>> valid_input_and_options
>> choose_a_branch_to_run
>> [task_1(), task_2(), task_3()]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,13 @@

from airflow.exceptions import AirflowOptionalProviderFeatureException
from airflow.providers.standard.version_compat import AIRFLOW_V_3_1_PLUS
from airflow.sdk.bases.notifier import BaseNotifier

if not AIRFLOW_V_3_1_PLUS:
raise AirflowOptionalProviderFeatureException("Human in the loop functionality needs Airflow 3.1+.")


from collections.abc import Collection, Mapping
from collections.abc import Collection, Mapping, Sequence
from typing import TYPE_CHECKING, Any

from airflow.providers.standard.exceptions import HITLTimeoutError, HITLTriggerEventError
Expand Down Expand Up @@ -66,6 +67,7 @@ def __init__(
defaults: str | list[str] | None = None,
multiple: bool = False,
params: ParamsDict | dict[str, Any] | None = None,
notifiers: Sequence[BaseNotifier] | BaseNotifier | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand All @@ -78,6 +80,9 @@ def __init__(
self.multiple = multiple

self.params: ParamsDict = params if isinstance(params, ParamsDict) else ParamsDict(params or {})
self.notifiers: Sequence[BaseNotifier] = (
[notifiers] if isinstance(notifiers, BaseNotifier) else notifiers or []
)

self.validate_defaults()

Expand Down Expand Up @@ -108,11 +113,16 @@ def execute(self, context: Context):
multiple=self.multiple,
params=self.serialized_params,
)

if self.execution_timeout:
timeout_datetime = utcnow() + self.execution_timeout
else:
timeout_datetime = None

self.log.info("Waiting for response")
for notifier in self.notifiers:
notifier(context)

# Defer the Human-in-the-loop response checking process to HITLTrigger
self.defer(
trigger=HITLTrigger(
Expand Down
6 changes: 6 additions & 0 deletions providers/standard/tests/unit/standard/operators/test_hitl.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

import datetime
from typing import TYPE_CHECKING, Any
from unittest.mock import MagicMock

import pytest
from sqlalchemy import select
Expand Down Expand Up @@ -96,6 +97,8 @@ def test_validate_defaults_with_invalid_defaults(
)

def test_execute(self, dag_maker: DagMaker, session: Session) -> None:
notifier = MagicMock()

with dag_maker("test_dag"):
task = HITLOperator(
task_id="hitl_test",
Expand All @@ -105,6 +108,7 @@ def test_execute(self, dag_maker: DagMaker, session: Session) -> None:
defaults=["1"],
multiple=False,
params=ParamsDict({"input_1": 1}),
notifiers=[notifier],
)
dr = dag_maker.create_dagrun()
ti = dag_maker.run_ti(task.task_id, dr)
Expand All @@ -122,6 +126,8 @@ def test_execute(self, dag_maker: DagMaker, session: Session) -> None:
assert hitl_detail_model.chosen_options is None
assert hitl_detail_model.params_input == {}

assert notifier.called is True

registered_trigger = session.scalar(
select(Trigger).where(Trigger.classpath == "airflow.providers.standard.triggers.hitl.HITLTrigger")
)
Expand Down
Loading