3 changes: 3 additions & 0 deletions providers/amazon/provider.yaml
@@ -702,6 +702,9 @@ triggers:
- integration-name: AWS Lambda
python-modules:
- airflow.providers.amazon.aws.triggers.lambda_function
- integration-name: Amazon Managed Workflows for Apache Airflow (MWAA)
python-modules:
- airflow.providers.amazon.aws.triggers.mwaa
- integration-name: Amazon Managed Service for Apache Flink
python-modules:
- airflow.providers.amazon.aws.triggers.kinesis_analytics
@@ -943,6 +943,7 @@ def get_waiter(
self,
waiter_name: str,
parameters: dict[str, str] | None = None,
config_overrides: dict[str, Any] | None = None,
deferrable: bool = False,
client=None,
) -> Waiter:
@@ -962,6 +963,9 @@ def get_waiter(
:param parameters: will scan the waiter config for the keys of that dict,
and replace them with the corresponding value. If a custom waiter has
such keys to be expanded, they need to be provided here.
Note: cannot be used when ``acceptors`` is included in ``config_overrides``
:param config_overrides: will update values of provided keys in the waiter's
config. Only specified keys will be updated.
:param deferrable: If True, the waiter is going to be an async custom waiter.
An async client must be provided in that case.
:param client: The client to use for the waiter's operations
@@ -970,14 +974,18 @@

if deferrable and not client:
raise ValueError("client must be provided for a deferrable waiter.")
if parameters is not None and config_overrides is not None and "acceptors" in config_overrides:
raise ValueError('parameters must be None when "acceptors" is included in config_overrides')
# Currently, the custom waiter doesn't work with resource_type, only client_type is supported.
client = client or self._client
if self.waiter_path and (waiter_name in self._list_custom_waiters()):
# Technically if waiter_name is in custom_waiters then self.waiter_path must
# exist but MyPy doesn't like the fact that self.waiter_path could be None.
with open(self.waiter_path) as config_file:
config = json.loads(config_file.read())
config: dict = json.loads(config_file.read())

if config_overrides is not None:
config["waiters"][waiter_name].update(config_overrides)
config = self._apply_parameters_value(config, waiter_name, parameters)
return BaseBotoWaiter(client=client, model_config=config, deferrable=deferrable).waiter(
waiter_name
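
For illustration, a minimal sketch of how the new ``config_overrides`` argument might be used with the custom MWAA waiter added in this PR (the environment name and DAG/run ids below are hypothetical):

from airflow.providers.amazon.aws.hooks.mwaa import MwaaHook

hook = MwaaHook(aws_conn_id="aws_default")
# Override only the polling cadence; "operation" and "acceptors" in the
# waiter's JSON config are left untouched.
waiter = hook.get_waiter(
    "mwaa_dag_run_complete",
    config_overrides={"delay": 30, "maxAttempts": 10},
)
waiter.wait(
    Name="my-mwaa-env",  # hypothetical environment name
    Path="/dags/my_dag/dagRuns/manual__2024-01-01",  # hypothetical DAG/run ids
    Method="GET",
)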
69 changes: 58 additions & 11 deletions providers/amazon/src/airflow/providers/amazon/aws/sensors/mwaa.py
@@ -18,13 +18,16 @@
from __future__ import annotations

from collections.abc import Collection, Sequence
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.mwaa import MwaaHook
from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
from airflow.providers.amazon.aws.triggers.mwaa import MwaaDagRunCompletedTrigger
from airflow.providers.amazon.aws.utils import validate_execute_complete_event
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
from airflow.utils.state import State
from airflow.utils.state import DagRunState

if TYPE_CHECKING:
from airflow.utils.context import Context
@@ -46,9 +49,24 @@ class MwaaDagRunSensor(AwsBaseSensor[MwaaHook]):
(templated)
:param external_dag_run_id: The DAG Run ID in the external MWAA environment that you want to wait for (templated)
:param success_states: Collection of DAG Run states that would make this task marked as successful, default is
``airflow.utils.state.State.success_states`` (templated)
``{airflow.utils.state.DagRunState.SUCCESS}`` (templated)
:param failure_states: Collection of DAG Run states that would make this task marked as failed and raise an
AirflowException, default is ``airflow.utils.state.State.failed_states`` (templated)
AirflowException, default is ``{airflow.utils.state.DagRunState.FAILED}`` (templated)
:param deferrable: If True, the sensor will operate in deferrable mode. This mode requires aiobotocore
module to be installed.
(default: False, but can be overridden in config file by setting default_deferrable to True)
:param poke_interval: Polling period in seconds to check for the status of the job. (default: 60)
:param max_retries: Number of times before returning the current state. (default: 720)
:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is ``None`` or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
empty, then default boto3 configuration would be used (and must be
maintained on each worker node).
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:param verify: Whether or not to verify SSL certificates. See:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:param botocore_config: Configuration dictionary (key-values) for botocore client. See:
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
"""

aws_hook_class = MwaaHook
@@ -58,6 +76,9 @@ class MwaaDagRunSensor(AwsBaseSensor[MwaaHook]):
"external_dag_run_id",
"success_states",
"failure_states",
"deferrable",
"max_retries",
"poke_interval",
)

def __init__(
@@ -68,19 +89,25 @@ def __init__(
external_dag_run_id: str,
success_states: Collection[str] | None = None,
failure_states: Collection[str] | None = None,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
poke_interval: int = 60,
max_retries: int = 720,
**kwargs,
):
super().__init__(**kwargs)

self.success_states = set(success_states if success_states else State.success_states)
self.failure_states = set(failure_states if failure_states else State.failed_states)
self.success_states = set(success_states) if success_states else {DagRunState.SUCCESS.value}
self.failure_states = set(failure_states) if failure_states else {DagRunState.FAILED.value}

if len(self.success_states & self.failure_states):
raise AirflowException("allowed_states and failed_states must not have any values in common")
raise ValueError("success_states and failure_states must not have any values in common")

self.external_env_name = external_env_name
self.external_dag_id = external_dag_id
self.external_dag_run_id = external_dag_run_id
self.deferrable = deferrable
self.poke_interval = poke_interval
self.max_retries = max_retries

def poke(self, context: Context) -> bool:
self.log.info(
@@ -102,12 +129,32 @@ def poke(self, context: Context) -> bool:
# The scope of this sensor is going to only be raising AirflowException due to failure of the DAGRun

state = response["RestApiResponse"]["state"]
if state in self.success_states:
return True

if state in self.failure_states:
raise AirflowException(
f"The DAG run {self.external_dag_run_id} of DAG {self.external_dag_id} in MWAA environment {self.external_env_name} "
f"failed with state {state}."
f"failed with state: {state}"
)
return False

return state in self.success_states

def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
validate_execute_complete_event(event)

def execute(self, context: Context):
if self.deferrable:
self.defer(
trigger=MwaaDagRunCompletedTrigger(
external_env_name=self.external_env_name,
external_dag_id=self.external_dag_id,
external_dag_run_id=self.external_dag_run_id,
success_states=self.success_states,
failure_states=self.failure_states,
waiter_delay=self.poke_interval,
waiter_max_attempts=self.max_retries,
aws_conn_id=self.aws_conn_id,
),
method_name="execute_complete",
)
else:
super().execute(context=context)
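
A minimal usage sketch for the deferrable path added above (task id, environment name, DAG id, and run id are hypothetical):

from airflow.providers.amazon.aws.sensors.mwaa import MwaaDagRunSensor

wait_for_external_run = MwaaDagRunSensor(
    task_id="wait_for_external_run",
    external_env_name="my-mwaa-env",
    external_dag_id="remote_etl_dag",
    external_dag_run_id="manual__2024-01-01",
    deferrable=True,  # hand polling off to the triggerer
    poke_interval=60,  # seconds between checks (becomes waiter_delay when deferred)
    max_retries=720,  # attempts before giving up (~12 hours at 60s intervals)
)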
@@ -55,6 +55,8 @@ class AwsBaseWaiterTrigger(BaseTrigger):

:param waiter_delay: The amount of time in seconds to wait between attempts.
:param waiter_max_attempts: The maximum number of attempts to be made.
:param waiter_config_overrides: A dict to update waiter's default configuration. Only specified keys will
be updated.
:param aws_conn_id: The Airflow connection used for AWS credentials. To be used to build the hook.
:param region_name: The AWS region where the resources to watch are. To be used to build the hook.
:param verify: Whether or not to verify SSL certificates. To be used to build the hook.
@@ -77,6 +79,7 @@
return_value: Any,
waiter_delay: int,
waiter_max_attempts: int,
waiter_config_overrides: dict[str, Any] | None = None,
aws_conn_id: str | None,
region_name: str | None = None,
verify: bool | str | None = None,
@@ -91,6 +94,7 @@
self.failure_message = failure_message
self.status_message = status_message
self.status_queries = status_queries
self.waiter_config_overrides = waiter_config_overrides

self.return_key = return_key
self.return_value = return_value
@@ -140,7 +144,12 @@ def hook(self) -> AwsGenericHook:
async def run(self) -> AsyncIterator[TriggerEvent]:
hook = self.hook()
async with await hook.get_async_conn() as client:
waiter = hook.get_waiter(self.waiter_name, deferrable=True, client=client)
waiter = hook.get_waiter(
self.waiter_name,
deferrable=True,
client=client,
config_overrides=self.waiter_config_overrides,
)
await async_wait(
waiter,
self.waiter_delay,
129 changes: 129 changes: 129 additions & 0 deletions providers/amazon/src/airflow/providers/amazon/aws/triggers/mwaa.py
@@ -0,0 +1,129 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from collections.abc import Collection
from typing import TYPE_CHECKING

from airflow.providers.amazon.aws.hooks.mwaa import MwaaHook
from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
from airflow.utils.state import DagRunState

if TYPE_CHECKING:
from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook


class MwaaDagRunCompletedTrigger(AwsBaseWaiterTrigger):
"""
Trigger when an MWAA DAG Run is complete.

:param external_env_name: The external MWAA environment name that contains the DAG Run you want to wait for
:param external_dag_id: The DAG ID in the external MWAA environment that contains the DAG Run you want to wait for
:param external_dag_run_id: The DAG Run ID in the external MWAA environment that you want to wait for
:param success_states: Collection of DAG Run states that would make this task marked as successful, default is
``{airflow.utils.state.DagRunState.SUCCESS}``
:param failure_states: Collection of DAG Run states that would make this task marked as failed and raise an
AirflowException, default is ``{airflow.utils.state.DagRunState.FAILED}``
:param waiter_delay: The amount of time in seconds to wait between attempts. (default: 60)
:param waiter_max_attempts: The maximum number of attempts to be made. (default: 720)
:param aws_conn_id: The Airflow connection used for AWS credentials.
"""

def __init__(
self,
*,
external_env_name: str,
external_dag_id: str,
external_dag_run_id: str,
success_states: Collection[str] | None = None,
failure_states: Collection[str] | None = None,
waiter_delay: int = 60,
waiter_max_attempts: int = 720,
aws_conn_id: str | None = None,
) -> None:
self.success_states = set(success_states) if success_states else {DagRunState.SUCCESS.value}
self.failure_states = set(failure_states) if failure_states else {DagRunState.FAILED.value}

if len(self.success_states & self.failure_states):
raise ValueError("success_states and failure_states must not have any values in common")

in_progress_states = {s.value for s in DagRunState} - self.success_states - self.failure_states

super().__init__(
serialized_fields={
"external_env_name": external_env_name,
"external_dag_id": external_dag_id,
"external_dag_run_id": external_dag_run_id,
"success_states": success_states,
"failure_states": failure_states,
},
waiter_name="mwaa_dag_run_complete",
waiter_args={
"Name": external_env_name,
"Path": f"/dags/{external_dag_id}/dagRuns/{external_dag_run_id}",
"Method": "GET",
},
failure_message=f"The DAG run {external_dag_run_id} of DAG {external_dag_id} in MWAA environment {external_env_name} failed with state",
status_message="State of DAG run",
status_queries=["RestApiResponse.state"],
return_key="dag_run_id",
return_value=external_dag_run_id,
waiter_delay=waiter_delay,
waiter_max_attempts=waiter_max_attempts,
aws_conn_id=aws_conn_id,
waiter_config_overrides={
"acceptors": _build_waiter_acceptors(
success_states=self.success_states,
failure_states=self.failure_states,
in_progress_states=in_progress_states,
)
},
)

def hook(self) -> AwsGenericHook:
return MwaaHook(
aws_conn_id=self.aws_conn_id,
region_name=self.region_name,
verify=self.verify,
config=self.botocore_config,
)


def _build_waiter_acceptors(
success_states: set[str], failure_states: set[str], in_progress_states: set[str]
) -> list:
def build_acceptor(dag_run_state: str, state_waiter_category: str):
return {
"matcher": "path",
"argument": "RestApiResponse.state",
"expected": dag_run_state,
"state": state_waiter_category,
}

acceptors = []
for state_set, state_waiter_category in (
(success_states, "success"),
(failure_states, "failure"),
(in_progress_states, "retry"),
):
for dag_run_state in state_set:
acceptors.append(build_acceptor(dag_run_state, state_waiter_category))

return acceptors
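
With the default states (success ``{"success"}``, failure ``{"failed"}``), the helper above yields acceptors equivalent to the following, which matches the defaults shipped in the ``mwaa_dag_run_complete`` waiter config further below (a sketch; ordering within each category follows set iteration and is not guaranteed):

[
    {"matcher": "path", "argument": "RestApiResponse.state", "expected": "success", "state": "success"},
    {"matcher": "path", "argument": "RestApiResponse.state", "expected": "failed", "state": "failure"},
    {"matcher": "path", "argument": "RestApiResponse.state", "expected": "queued", "state": "retry"},
    {"matcher": "path", "argument": "RestApiResponse.state", "expected": "running", "state": "retry"},
]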
@@ -136,15 +136,16 @@ async def async_wait(
last_response = error.last_response

if "terminal failure" in error_reason:
log.error("%s: %s", failure_message, _LazyStatusFormatter(status_args, last_response))
raise AirflowException(f"{failure_message}: {error}")
raise AirflowException(
f"{failure_message}: {_LazyStatusFormatter(status_args, last_response)}\n{error}"
)

if (
"An error occurred" in error_reason
and isinstance(last_response.get("Error"), dict)
and "Code" in last_response.get("Error")
):
raise AirflowException(f"{failure_message}: {error}")
raise AirflowException(f"{failure_message}\n{last_response}\n{error}")

log.info("%s: %s", status_message, _LazyStatusFormatter(status_args, last_response))
else:
@@ -0,0 +1,36 @@
{
"version": 2,
"waiters": {
"mwaa_dag_run_complete": {
"delay": 60,
"maxAttempts": 720,
"operation": "InvokeRestApi",
"acceptors": [
{
"matcher": "path",
"argument": "RestApiResponse.state",
"expected": "queued",
"state": "retry"
},
{
"matcher": "path",
"argument": "RestApiResponse.state",
"expected": "running",
"state": "retry"
},
{
"matcher": "path",
"argument": "RestApiResponse.state",
"expected": "success",
"state": "success"
},
{
"matcher": "path",
"argument": "RestApiResponse.state",
"expected": "failed",
"state": "failure"
}
]
}
}
}
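
As a sanity check, the waiter definition above parses directly with botocore's waiter machinery (a sketch; the file path is an assumption about where this new JSON file lives):

import json

from botocore.waiter import WaiterModel

# Assumed path for the waiter config added in this PR.
with open("providers/amazon/src/airflow/providers/amazon/aws/waiters/mwaa.json") as f:
    model = WaiterModel(json.load(f))

print(model.waiter_names)  # ["mwaa_dag_run_complete"]
print(model.get_waiter("mwaa_dag_run_complete").operation)  # "InvokeRestApi"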