Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 5 additions & 14 deletions airflow/providers/amazon/aws/operators/ecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,18 @@

from airflow.configuration import conf
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.exceptions import EcsOperatorError, EcsTaskFailToStart
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.providers.amazon.aws.hooks.ecs import EcsClusterStates, EcsHook, should_retry_eni
from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
from airflow.providers.amazon.aws.triggers.ecs import (
ClusterActiveTrigger,
ClusterInactiveTrigger,
TaskDoneTrigger,
)
from airflow.providers.amazon.aws.utils.identifiers import generate_uuid
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher
from airflow.utils.helpers import prune_dict

Expand All @@ -45,21 +46,11 @@
from airflow.models import TaskInstance
from airflow.utils.context import Context

DEFAULT_CONN_ID = "aws_default"


class EcsBaseOperator(BaseOperator):
class EcsBaseOperator(AwsBaseOperator[EcsHook]):
"""This is the base operator for all Elastic Container Service operators."""

def __init__(self, *, aws_conn_id: str | None = DEFAULT_CONN_ID, region: str | None = None, **kwargs):
self.aws_conn_id = aws_conn_id
self.region = region
super().__init__(**kwargs)

@cached_property
def hook(self) -> EcsHook:
"""Create and return an EcsHook."""
return EcsHook(aws_conn_id=self.aws_conn_id, region_name=self.region)
aws_hook_class = EcsHook

@cached_property
def client(self) -> boto3.client:
Expand Down Expand Up @@ -101,7 +92,7 @@ class EcsCreateClusterOperator(EcsBaseOperator):
(default: False)
"""

template_fields: Sequence[str] = (
template_fields: Sequence[str] = aws_template_fields(
"cluster_name",
"create_cluster_kwargs",
"wait_for_completion",
Expand Down
23 changes: 7 additions & 16 deletions airflow/providers/amazon/aws/sensors/ecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,14 @@
EcsTaskDefinitionStates,
EcsTaskStates,
)
from airflow.sensors.base import BaseSensorOperator
from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields

if TYPE_CHECKING:
import boto3

from airflow.utils.context import Context

DEFAULT_CONN_ID: str = "aws_default"


def _check_failed(current_state, target_state, failure_states, soft_fail: bool) -> None:
if (current_state != target_state) and (current_state in failure_states):
Expand All @@ -45,18 +44,10 @@ def _check_failed(current_state, target_state, failure_states, soft_fail: bool)
raise AirflowException(message)


class EcsBaseSensor(BaseSensorOperator):
class EcsBaseSensor(AwsBaseSensor[EcsHook]):
"""Contains general sensor behavior for Elastic Container Service."""

def __init__(self, *, aws_conn_id: str | None = DEFAULT_CONN_ID, region: str | None = None, **kwargs):
self.aws_conn_id = aws_conn_id
self.region = region
super().__init__(**kwargs)

@cached_property
def hook(self) -> EcsHook:
"""Create and return an EcsHook."""
return EcsHook(aws_conn_id=self.aws_conn_id, region_name=self.region)
aws_hook_class = EcsHook

@cached_property
def client(self) -> boto3.client:
Expand All @@ -78,7 +69,7 @@ class EcsClusterStateSensor(EcsBaseSensor):
Success State. (Default: "FAILED" or "INACTIVE")
"""

template_fields: Sequence[str] = ("cluster_name", "target_state", "failure_states")
template_fields: Sequence[str] = aws_template_fields("cluster_name", "target_state", "failure_states")

def __init__(
self,
Expand Down Expand Up @@ -116,7 +107,7 @@ class EcsTaskDefinitionStateSensor(EcsBaseSensor):
:param target_state: Success state to watch for. (Default: "ACTIVE")
"""

template_fields: Sequence[str] = ("task_definition", "target_state", "failure_states")
template_fields: Sequence[str] = aws_template_fields("task_definition", "target_state", "failure_states")

def __init__(
self,
Expand Down Expand Up @@ -162,7 +153,7 @@ class EcsTaskStateSensor(EcsBaseSensor):
the Success State. (Default: "STOPPED")
"""

template_fields: Sequence[str] = ("cluster", "task", "target_state", "failure_states")
template_fields: Sequence[str] = aws_template_fields("cluster", "task", "target_state", "failure_states")

def __init__(
self,
Expand Down
4 changes: 4 additions & 0 deletions airflow/providers/amazon/aws/triggers/ecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def __init__(
waiter_max_attempts: int,
aws_conn_id: str | None,
region_name: str | None = None,
**kwargs,
):
super().__init__(
serialized_fields={"cluster_arn": cluster_arn},
Expand All @@ -66,6 +67,7 @@ def __init__(
waiter_max_attempts=waiter_max_attempts,
aws_conn_id=aws_conn_id,
region_name=region_name,
**kwargs,
)

def hook(self) -> AwsGenericHook:
Expand All @@ -91,6 +93,7 @@ def __init__(
waiter_max_attempts: int,
aws_conn_id: str | None,
region_name: str | None = None,
**kwargs,
):
super().__init__(
serialized_fields={"cluster_arn": cluster_arn},
Expand All @@ -104,6 +107,7 @@ def __init__(
waiter_max_attempts=waiter_max_attempts,
aws_conn_id=aws_conn_id,
region_name=region_name,
**kwargs,
)

def hook(self) -> AwsGenericHook:
Expand Down
4 changes: 4 additions & 0 deletions docs/apache-airflow-providers-amazon/operators/ecs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ Prerequisite Tasks

.. include:: ../_partials/prerequisite_tasks.rst

Generic Parameters
------------------
.. include:: ../_partials/generic_parameters.rst

Operators
---------

Expand Down
31 changes: 14 additions & 17 deletions tests/providers/amazon/aws/operators/test_ecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
from airflow.providers.amazon.aws.exceptions import EcsOperatorError, EcsTaskFailToStart
from airflow.providers.amazon.aws.hooks.ecs import EcsClusterStates, EcsHook
from airflow.providers.amazon.aws.operators.ecs import (
DEFAULT_CONN_ID,
EcsBaseOperator,
EcsCreateClusterOperator,
EcsDeleteClusterOperator,
Expand Down Expand Up @@ -112,30 +111,28 @@ def test_initialise_operator(self, aws_conn_id, region_name):
op_kw = {k: v for k, v in op_kw.items() if v is not NOTSET}
op = EcsBaseOperator(task_id="test_ecs_base", **op_kw)

assert op.aws_conn_id == (aws_conn_id if aws_conn_id is not NOTSET else DEFAULT_CONN_ID)
assert op.aws_conn_id == (aws_conn_id if aws_conn_id is not NOTSET else "aws_default")
assert op.region == (region_name if region_name is not NOTSET else None)

@mock.patch("airflow.providers.amazon.aws.operators.ecs.EcsHook")
@pytest.mark.parametrize("aws_conn_id", [None, NOTSET, "aws_test_conn"])
@pytest.mark.parametrize("region_name", [None, NOTSET, "ca-central-1"])
def test_hook_and_client(self, mock_ecs_hook_cls, aws_conn_id, region_name):
"""Test initialize ``EcsHook`` and ``boto3.client``."""
mock_ecs_hook = mock_ecs_hook_cls.return_value
mock_conn = mock.MagicMock()
type(mock_ecs_hook).conn = mock.PropertyMock(return_value=mock_conn)

def test_initialise_operator_hook(self, aws_conn_id, region_name):
"""Test initialize operator."""
op_kw = {"aws_conn_id": aws_conn_id, "region": region_name}
op_kw = {k: v for k, v in op_kw.items() if v is not NOTSET}
op = EcsBaseOperator(task_id="test_ecs_base_hook_client", **op_kw)
op = EcsBaseOperator(task_id="test_ecs_base", **op_kw)

assert op.hook.aws_conn_id == (aws_conn_id if aws_conn_id is not NOTSET else "aws_default")
assert op.hook.region_name == (region_name if region_name is not NOTSET else None)

hook = op.hook
assert op.hook is hook
mock_ecs_hook_cls.assert_called_once_with(aws_conn_id=op.aws_conn_id, region_name=op.region)
with mock.patch.object(EcsBaseOperator, "hook", new_callable=mock.PropertyMock) as m:
mocked_hook = mock.MagicMock(name="MockHook")
mocked_client = mock.MagicMock(name="Mocklient")
mocked_hook.conn = mocked_client
m.return_value = mocked_hook

client = op.client
mock_ecs_hook_cls.assert_called_once_with(aws_conn_id=op.aws_conn_id, region_name=op.region)
assert client == mock_conn
assert op.client is client
assert op.client == mocked_client
m.assert_called_once()


class TestEcsRunTaskOperator(EcsBaseTestCase):
Expand Down
3 changes: 1 addition & 2 deletions tests/providers/amazon/aws/sensors/test_ecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@

from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.providers.amazon.aws.sensors.ecs import (
DEFAULT_CONN_ID,
EcsBaseSensor,
EcsClusterStates,
EcsClusterStateSensor,
Expand Down Expand Up @@ -79,7 +78,7 @@ def test_initialise_operator(self, aws_conn_id, region_name):
op_kw = {k: v for k, v in op_kw.items() if v is not NOTSET}
op = EcsBaseSensor(task_id="test_ecs_base", **op_kw)

assert op.aws_conn_id == (aws_conn_id if aws_conn_id is not NOTSET else DEFAULT_CONN_ID)
assert op.aws_conn_id == (aws_conn_id if aws_conn_id is not NOTSET else "aws_default")
assert op.region == (region_name if region_name is not NOTSET else None)

@pytest.mark.parametrize("aws_conn_id", [None, NOTSET, "aws_test_conn"])
Expand Down