Skip to content

Commit

Permalink
Standardize AWS ECS naming (#20332)
Browse files Browse the repository at this point in the history
* Rename ECS Hook and Operator
  • Loading branch information
ferruzzi authored Jan 4, 2022
1 parent f864d14 commit 9c0ba1b
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.

"""
This is an example dag for ECSOperator.
This is an example dag for EcsOperator.
The task "hello_world" runs `hello-world` task in `c` cluster.
It overrides the command in the `hello-world-container` container.
Expand All @@ -26,7 +26,7 @@
import os

from airflow import DAG
from airflow.providers.amazon.aws.operators.ecs import ECSOperator
from airflow.providers.amazon.aws.operators.ecs import EcsOperator

dag = DAG(
dag_id="ecs_fargate_dag",
Expand All @@ -40,7 +40,7 @@
dag.doc_md = __doc__

# [START howto_operator_ecs]
hello_world = ECSOperator(
hello_world = EcsOperator(
task_id="hello_world",
dag=dag,
aws_conn_id="aws_ecs",
Expand Down
19 changes: 18 additions & 1 deletion airflow/providers/amazon/aws/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,29 @@
#
# Note: Any AirflowException raised is expected to cause the TaskInstance
# to be marked in an ERROR state
import warnings


class ECSOperatorError(Exception):
class EcsOperatorError(Exception):
"""Raise when ECS cannot handle the request."""

def __init__(self, failures: list, message: str):
self.failures = failures
self.message = message
super().__init__(message)


class ECSOperatorError(EcsOperatorError):
"""
This class is deprecated.
Please use :class:`airflow.providers.amazon.aws.exceptions.EcsOperatorError`.
"""

def __init__(self, *args, **kwargs):
warnings.warn(
"This class is deprecated. "
"Please use `airflow.providers.amazon.aws.exceptions.EcsOperatorError`.",
DeprecationWarning,
stacklevel=2,
)
super().__init__(*args, **kwargs)
75 changes: 62 additions & 13 deletions airflow/providers/amazon/aws/operators/ecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import re
import sys
import time
import warnings
from collections import deque
from datetime import datetime, timedelta
from logging import Logger
Expand All @@ -29,7 +30,7 @@

from airflow.exceptions import AirflowException
from airflow.models import BaseOperator, XCom
from airflow.providers.amazon.aws.exceptions import ECSOperatorError
from airflow.providers.amazon.aws.exceptions import EcsOperatorError
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
from airflow.typing_compat import Protocol, runtime_checkable
Expand All @@ -38,7 +39,7 @@

def should_retry(exception: Exception):
"""Check if exception is related to ECS resource quota (CPU, MEM)."""
if isinstance(exception, ECSOperatorError):
if isinstance(exception, EcsOperatorError):
return any(
quota_reason in failure['reason']
for quota_reason in ['RESOURCE:MEMORY', 'RESOURCE:CPU']
Expand All @@ -48,10 +49,10 @@ def should_retry(exception: Exception):


@runtime_checkable
class ECSProtocol(Protocol):
class EcsProtocol(Protocol):
"""
A structured Protocol for ``boto3.client('ecs')``. This is used for type hints on
:py:meth:`.ECSOperator.client`.
:py:meth:`.EcsOperator.client`.
.. seealso::
Expand Down Expand Up @@ -84,7 +85,7 @@ def list_tasks(self, cluster: str, launchType: str, desiredStatus: str, family:
...


class ECSTaskLogFetcher(Thread):
class EcsTaskLogFetcher(Thread):
"""
Fetches Cloudwatch log events with specific interval as a thread
and sends the log events to the info channel of the provided logger.
Expand Down Expand Up @@ -151,13 +152,13 @@ def stop(self):
self._event.set()


class ECSOperator(BaseOperator):
class EcsOperator(BaseOperator):
"""
Execute a task on AWS ECS (Elastic Container Service)
.. seealso::
For more information on how to use this operator, take a look at the guide:
:ref:`howto/operator:ECSOperator`
:ref:`howto/operator:EcsOperator`
:param task_definition: the task definition name on Elastic Container Service
:type task_definition: str
Expand Down Expand Up @@ -289,17 +290,17 @@ def __init__(
self.awslogs_region = region_name

self.hook: Optional[AwsBaseHook] = None
self.client: Optional[ECSProtocol] = None
self.client: Optional[EcsProtocol] = None
self.arn: Optional[str] = None
self.retry_args = quota_retry
self.task_log_fetcher: Optional[ECSTaskLogFetcher] = None
self.task_log_fetcher: Optional[EcsTaskLogFetcher] = None

@provide_session
def execute(self, context, session=None):
self.log.info(
'Running ECS Task - Task definition: %s - on cluster %s', self.task_definition, self.cluster
)
self.log.info('ECSOperator overrides: %s', self.overrides)
self.log.info('EcsOperator overrides: %s', self.overrides)

self.client = self.get_hook().get_conn()

Expand Down Expand Up @@ -371,7 +372,7 @@ def _start_task(self, context):

failures = response['failures']
if len(failures) > 0:
raise ECSOperatorError(failures, response)
raise EcsOperatorError(failures, response)
self.log.info('ECS Task started: %s', response)

self.arn = response['tasks'][0]['taskArn']
Expand Down Expand Up @@ -430,11 +431,12 @@ def _wait_for_task_ended(self) -> None:
def _aws_logs_enabled(self):
return self.awslogs_group and self.awslogs_stream_prefix

def _get_task_log_fetcher(self) -> ECSTaskLogFetcher:
def _get_task_log_fetcher(self) -> EcsTaskLogFetcher:
if not self.awslogs_group:
raise ValueError("must specify awslogs_group to fetch task logs")
log_stream_name = f"{self.awslogs_stream_prefix}/{self.ecs_task_id}"
return ECSTaskLogFetcher(

return EcsTaskLogFetcher(
aws_conn_id=self.aws_conn_id,
region_name=self.awslogs_region,
log_group=self.awslogs_group,
Expand Down Expand Up @@ -509,3 +511,50 @@ def on_kill(self) -> None:
cluster=self.cluster, task=self.arn, reason='Task killed by the user'
)
self.log.info(response)


class ECSOperator(EcsOperator):
"""
This operator is deprecated.
Please use :class:`airflow.providers.amazon.aws.operators.ecs.EcsOperator`.
"""

def __init__(self, *args, **kwargs):
warnings.warn(
"This operator is deprecated. "
"Please use `airflow.providers.amazon.aws.operators.ecs.EcsOperator`.",
DeprecationWarning,
stacklevel=2,
)
super().__init__(*args, **kwargs)


class ECSTaskLogFetcher(EcsTaskLogFetcher):
"""
This class is deprecated.
Please use :class:`airflow.providers.amazon.aws.operators.ecs.EcsTaskLogFetcher`.
"""

def __init__(self, *args, **kwargs):
warnings.warn(
"This class is deprecated. "
"Please use `airflow.providers.amazon.aws.operators.ecs.EcsTaskLogFetcher`.",
DeprecationWarning,
stacklevel=2,
)
super().__init__(*args, **kwargs)


class ECSProtocol(EcsProtocol):
"""
This class is deprecated.
Please use :class:`airflow.providers.amazon.aws.operators.ecs.EcsProtocol`.
"""

def __init__(self):
warnings.warn(
"This class is deprecated. "
"Please use `airflow.providers.amazon.aws.operators.ecs.EcsProtocol`.",
DeprecationWarning,
stacklevel=2,
)
6 changes: 3 additions & 3 deletions docs/apache-airflow-providers-amazon/operators/ecs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
under the License.
.. _howto/operator:ECSOperator:
.. _howto/operator:EcsOperator:

ECS Operator
============
Expand All @@ -30,14 +30,14 @@ Using Operator
--------------

Use the
:class:`~airflow.providers.amazon.aws.operators.ecs.ECSOperator`
:class:`~airflow.providers.amazon.aws.operators.ecs.EcsOperator`
to run a task defined in AWS ECS.

In the following example,
the task "hello_world" runs ``hello-world`` task in ``c`` cluster.
It overrides the command in the ``hello-world-container`` container.

Before using ECSOperator, *cluster* and *task definition* need to be created.
Before using EcsOperator, *cluster* and *task definition* need to be created.

.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_ecs_fargate.py
:language: python
Expand Down
42 changes: 21 additions & 21 deletions tests/providers/amazon/aws/operators/test_ecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@
from parameterized import parameterized

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.exceptions import ECSOperatorError
from airflow.providers.amazon.aws.operators.ecs import ECSOperator, ECSTaskLogFetcher, should_retry
from airflow.providers.amazon.aws.exceptions import EcsOperatorError
from airflow.providers.amazon.aws.operators.ecs import EcsOperator, EcsTaskLogFetcher, should_retry

# fmt: off
RESPONSE_WITHOUT_FAILURES = {
Expand All @@ -55,7 +55,7 @@
# fmt: on


class TestECSOperator(unittest.TestCase):
class TestEcsOperator(unittest.TestCase):
@mock.patch('airflow.providers.amazon.aws.operators.ecs.AwsBaseHook')
def set_up_operator(self, aws_hook_mock, **kwargs):
self.aws_hook_mock = aws_hook_mock
Expand All @@ -77,7 +77,7 @@ def set_up_operator(self, aws_hook_mock, **kwargs):
},
'propagate_tags': 'TASK_DEFINITION',
}
self.ecs = ECSOperator(**self.ecs_operator_args, **kwargs)
self.ecs = EcsOperator(**self.ecs_operator_args, **kwargs)
self.ecs.get_hook()

def setUp(self):
Expand Down Expand Up @@ -163,8 +163,8 @@ def test_template_fields_overrides(self):
],
]
)
@mock.patch.object(ECSOperator, '_wait_for_task_ended')
@mock.patch.object(ECSOperator, '_check_success_task')
@mock.patch.object(EcsOperator, '_wait_for_task_ended')
@mock.patch.object(EcsOperator, '_check_success_task')
def test_execute_without_failures(
self,
launch_type,
Expand Down Expand Up @@ -214,7 +214,7 @@ def test_execute_with_failures(self):
resp_failures['failures'].append('dummy error')
client_mock.run_task.return_value = resp_failures

with pytest.raises(ECSOperatorError):
with pytest.raises(EcsOperatorError):
self.ecs.execute(None)

self.aws_hook_mock.return_value.get_conn.assert_called_once()
Expand Down Expand Up @@ -429,15 +429,15 @@ def test_check_success_task_not_raises(self):
['', {'testTagKey': 'testTagValue'}],
]
)
@mock.patch.object(ECSOperator, "_xcom_del")
@mock.patch.object(EcsOperator, "_xcom_del")
@mock.patch.object(
ECSOperator,
EcsOperator,
"xcom_pull",
return_value="arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b55",
)
@mock.patch.object(ECSOperator, '_wait_for_task_ended')
@mock.patch.object(ECSOperator, '_check_success_task')
@mock.patch.object(ECSOperator, '_start_task')
@mock.patch.object(EcsOperator, '_wait_for_task_ended')
@mock.patch.object(EcsOperator, '_check_success_task')
@mock.patch.object(EcsOperator, '_start_task')
def test_reattach_successful(
self, launch_type, tags, start_mock, check_mock, wait_mock, xcom_pull_mock, xcom_del_mock
):
Expand Down Expand Up @@ -487,11 +487,11 @@ def test_reattach_successful(
['', {'testTagKey': 'testTagValue'}],
]
)
@mock.patch.object(ECSOperator, '_xcom_del')
@mock.patch.object(ECSOperator, '_xcom_set')
@mock.patch.object(ECSOperator, '_try_reattach_task')
@mock.patch.object(ECSOperator, '_wait_for_task_ended')
@mock.patch.object(ECSOperator, '_check_success_task')
@mock.patch.object(EcsOperator, '_xcom_del')
@mock.patch.object(EcsOperator, '_xcom_set')
@mock.patch.object(EcsOperator, '_try_reattach_task')
@mock.patch.object(EcsOperator, '_wait_for_task_ended')
@mock.patch.object(EcsOperator, '_check_success_task')
def test_reattach_save_task_arn_xcom(
self, launch_type, tags, check_mock, wait_mock, reattach_mock, xcom_set_mock, xcom_del_mock
):
Expand Down Expand Up @@ -552,18 +552,18 @@ def test_execute_xcom_disabled(self):

class TestShouldRetry(unittest.TestCase):
def test_return_true_on_valid_reason(self):
self.assertTrue(should_retry(ECSOperatorError([{'reason': 'RESOURCE:MEMORY'}], 'Foo')))
self.assertTrue(should_retry(EcsOperatorError([{'reason': 'RESOURCE:MEMORY'}], 'Foo')))

def test_return_false_on_invalid_reason(self):
self.assertFalse(should_retry(ECSOperatorError([{'reason': 'CLUSTER_NOT_FOUND'}], 'Foo')))
self.assertFalse(should_retry(EcsOperatorError([{'reason': 'CLUSTER_NOT_FOUND'}], 'Foo')))


class TestECSTaskLogFetcher(unittest.TestCase):
class TestEcsTaskLogFetcher(unittest.TestCase):
@mock.patch('logging.Logger')
def set_up_log_fetcher(self, logger_mock):
self.logger_mock = logger_mock

self.log_fetcher = ECSTaskLogFetcher(
self.log_fetcher = EcsTaskLogFetcher(
log_group="test_log_group",
log_stream_name="test_log_stream_name",
fetch_interval=timedelta(milliseconds=1),
Expand Down
2 changes: 1 addition & 1 deletion tests/providers/amazon/aws/operators/test_ecs_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@


@pytest.mark.backend("postgres", "mysql")
class ECSSystemTest(AmazonSystemTest):
class EcsSystemTest(AmazonSystemTest):
"""
ECS System Test to run and test example ECS dags
Expand Down

0 comments on commit 9c0ba1b

Please sign in to comment.