Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 33 additions & 30 deletions airflow/providers/amazon/aws/operators/step_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,16 @@

from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook
from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
from airflow.providers.amazon.aws.triggers.step_function import StepFunctionsExecutionCompleteTrigger
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields

if TYPE_CHECKING:
from airflow.utils.context import Context


class StepFunctionStartExecutionOperator(BaseOperator):
class StepFunctionStartExecutionOperator(AwsBaseOperator[StepFunctionHook]):
"""
An Operator that begins execution of an AWS Step Function State Machine.

Expand All @@ -50,10 +51,20 @@ class StepFunctionStartExecutionOperator(BaseOperator):
:param deferrable: If True, the operator will wait asynchronously for the job to complete.
This implies waiting for completion. This mode requires aiobotocore module to be installed.
(default: False, but can be overridden in config file by setting default_deferrable to True)
:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is ``None`` or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
empty, then default boto3 configuration would be used (and must be
maintained on each worker node).
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:param verify: Whether or not to verify SSL certificates. See:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:param botocore_config: Configuration dictionary (key-values) for botocore client. See:
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
"""

template_fields: Sequence[str] = ("state_machine_arn", "name", "input")
template_ext: Sequence[str] = ()
aws_hook_class = StepFunctionHook
template_fields: Sequence[str] = aws_template_fields("state_machine_arn", "name", "input")
ui_color = "#f9c915"

def __init__(
Expand All @@ -62,8 +73,6 @@ def __init__(
state_machine_arn: str,
name: str | None = None,
state_machine_input: dict | str | None = None,
aws_conn_id: str = "aws_default",
region_name: str | None = None,
waiter_max_attempts: int = 30,
waiter_delay: int = 60,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
Expand All @@ -73,18 +82,12 @@ def __init__(
self.state_machine_arn = state_machine_arn
self.name = name
self.input = state_machine_input
self.aws_conn_id = aws_conn_id
self.region_name = region_name
self.waiter_delay = waiter_delay
self.waiter_max_attempts = waiter_max_attempts
self.deferrable = deferrable

def execute(self, context: Context):
hook = StepFunctionHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)

execution_arn = hook.start_execution(self.state_machine_arn, self.name, self.input)

if execution_arn is None:
if not (execution_arn := self.hook.start_execution(self.state_machine_arn, self.name, self.input)):
raise AirflowException(f"Failed to start State Machine execution for: {self.state_machine_arn}")

self.log.info("Started State Machine execution for %s: %s", self.state_machine_arn, execution_arn)
Expand All @@ -96,6 +99,8 @@ def execute(self, context: Context):
waiter_max_attempts=self.waiter_max_attempts,
aws_conn_id=self.aws_conn_id,
region_name=self.region_name,
botocore_config=self.botocore_config,
verify=self.verify,
),
method_name="execute_complete",
timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay),
Expand All @@ -110,7 +115,7 @@ def execute_complete(self, context: Context, event: dict[str, Any] | None = None
return event["execution_arn"]


class StepFunctionGetExecutionOutputOperator(BaseOperator):
class StepFunctionGetExecutionOutputOperator(AwsBaseOperator[StepFunctionHook]):
"""
An Operator that returns the output of an AWS Step Function State Machine execution.

Expand All @@ -121,30 +126,28 @@ class StepFunctionGetExecutionOutputOperator(BaseOperator):
:ref:`howto/operator:StepFunctionGetExecutionOutputOperator`

:param execution_arn: ARN of the Step Function State Machine Execution
:param aws_conn_id: aws connection to use, defaults to 'aws_default'
:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is ``None`` or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
empty, then default boto3 configuration would be used (and must be
maintained on each worker node).
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:param verify: Whether or not to verify SSL certificates. See:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:param botocore_config: Configuration dictionary (key-values) for botocore client. See:
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
"""

template_fields: Sequence[str] = ("execution_arn",)
template_ext: Sequence[str] = ()
aws_hook_class = StepFunctionHook
template_fields: Sequence[str] = aws_template_fields("execution_arn")
ui_color = "#f9c915"

def __init__(
self,
*,
execution_arn: str,
aws_conn_id: str = "aws_default",
region_name: str | None = None,
**kwargs,
):
def __init__(self, *, execution_arn: str, **kwargs):
super().__init__(**kwargs)
self.execution_arn = execution_arn
self.aws_conn_id = aws_conn_id
self.region_name = region_name

def execute(self, context: Context):
hook = StepFunctionHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)

execution_status = hook.describe_execution(self.execution_arn)
execution_status = self.hook.describe_execution(self.execution_arn)
response = None
if "output" in execution_status:
response = json.loads(execution_status["output"])
Expand Down
36 changes: 16 additions & 20 deletions airflow/providers/amazon/aws/sensors/step_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,20 @@
from __future__ import annotations

import json
from functools import cached_property
from typing import TYPE_CHECKING, Sequence

from deprecated import deprecated

from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException
from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook
from airflow.sensors.base import BaseSensorOperator
from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields

if TYPE_CHECKING:
from airflow.utils.context import Context


class StepFunctionExecutionSensor(BaseSensorOperator):
class StepFunctionExecutionSensor(AwsBaseSensor[StepFunctionHook]):
"""
Poll the Step Function State Machine Execution until it reaches a terminal state; fails if the task fails.

Expand All @@ -42,7 +42,16 @@ class StepFunctionExecutionSensor(BaseSensorOperator):
:ref:`howto/sensor:StepFunctionExecutionSensor`

:param execution_arn: execution_arn to check the state of
:param aws_conn_id: aws connection to use, defaults to 'aws_default'
:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is ``None`` or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
empty, then default boto3 configuration would be used (and must be
maintained on each worker node).
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:param verify: Whether or not to verify SSL certificates. See:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:param botocore_config: Configuration dictionary (key-values) for botocore client. See:
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
"""

INTERMEDIATE_STATES = ("RUNNING",)
Expand All @@ -53,22 +62,13 @@ class StepFunctionExecutionSensor(BaseSensorOperator):
)
SUCCESS_STATES = ("SUCCEEDED",)

template_fields: Sequence[str] = ("execution_arn",)
template_ext: Sequence[str] = ()
aws_hook_class = StepFunctionHook
template_fields: Sequence[str] = aws_template_fields("execution_arn")
ui_color = "#66c3ff"

def __init__(
self,
*,
execution_arn: str,
aws_conn_id: str = "aws_default",
region_name: str | None = None,
**kwargs,
):
def __init__(self, *, execution_arn: str, **kwargs):
super().__init__(**kwargs)
self.execution_arn = execution_arn
self.aws_conn_id = aws_conn_id
self.region_name = region_name

def poke(self, context: Context):
execution_status = self.hook.describe_execution(self.execution_arn)
Expand All @@ -93,7 +93,3 @@ def poke(self, context: Context):
def get_hook(self) -> StepFunctionHook:
"""Create and return a StepFunctionHook."""
return self.hook

@cached_property
def hook(self) -> StepFunctionHook:
return StepFunctionHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
9 changes: 8 additions & 1 deletion airflow/providers/amazon/aws/triggers/step_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def __init__(
waiter_max_attempts: int = 30,
aws_conn_id: str | None = None,
region_name: str | None = None,
**kwargs,
) -> None:
super().__init__(
serialized_fields={"execution_arn": execution_arn, "region_name": region_name},
Expand All @@ -56,7 +57,13 @@ def __init__(
waiter_delay=waiter_delay,
waiter_max_attempts=waiter_max_attempts,
aws_conn_id=aws_conn_id,
**kwargs,
)

def hook(self) -> AwsGenericHook:
return StepFunctionHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
return StepFunctionHook(
aws_conn_id=self.aws_conn_id,
region_name=self.region_name,
verify=self.verify,
config=self.botocore_config,
)
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ Prerequisite Tasks

.. include:: ../_partials/prerequisite_tasks.rst

Generic Parameters
------------------

.. include:: ../_partials/generic_parameters.rst

Operators
---------

Expand Down
Loading