Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Standardize AWS ECS naming #20332

Merged
merged 2 commits into from
Jan 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.

"""
This is an example dag for ECSOperator.
This is an example dag for EcsOperator.

The task "hello_world" runs `hello-world` task in `c` cluster.
It overrides the command in the `hello-world-container` container.
Expand All @@ -26,7 +26,7 @@
import os

from airflow import DAG
from airflow.providers.amazon.aws.operators.ecs import ECSOperator
from airflow.providers.amazon.aws.operators.ecs import EcsOperator

dag = DAG(
dag_id="ecs_fargate_dag",
Expand All @@ -40,7 +40,7 @@
dag.doc_md = __doc__

# [START howto_operator_ecs]
hello_world = ECSOperator(
hello_world = EcsOperator(
task_id="hello_world",
dag=dag,
aws_conn_id="aws_ecs",
Expand Down
19 changes: 18 additions & 1 deletion airflow/providers/amazon/aws/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,29 @@
#
# Note: Any AirflowException raised is expected to cause the TaskInstance
# to be marked in an ERROR state
import warnings


class ECSOperatorError(Exception):
class EcsOperatorError(Exception):
"""Raise when ECS cannot handle the request."""

def __init__(self, failures: list, message: str):
self.failures = failures
self.message = message
super().__init__(message)
eladkal marked this conversation as resolved.
Show resolved Hide resolved


class ECSOperatorError(EcsOperatorError):
"""
This class is deprecated.
Please use :class:`airflow.providers.amazon.aws.exceptions.EcsOperatorError`.
"""

def __init__(self, *args, **kwargs):
warnings.warn(
"This class is deprecated. "
"Please use `airflow.providers.amazon.aws.exceptions.EcsOperatorError`.",
DeprecationWarning,
stacklevel=2,
)
super().__init__(*args, **kwargs)
75 changes: 62 additions & 13 deletions airflow/providers/amazon/aws/operators/ecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import re
import sys
import time
import warnings
from collections import deque
from datetime import datetime, timedelta
from logging import Logger
Expand All @@ -29,7 +30,7 @@

from airflow.exceptions import AirflowException
from airflow.models import BaseOperator, XCom
from airflow.providers.amazon.aws.exceptions import ECSOperatorError
from airflow.providers.amazon.aws.exceptions import EcsOperatorError
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
from airflow.typing_compat import Protocol, runtime_checkable
Expand All @@ -38,7 +39,7 @@

def should_retry(exception: Exception):
"""Check if exception is related to ECS resource quota (CPU, MEM)."""
if isinstance(exception, ECSOperatorError):
if isinstance(exception, EcsOperatorError):
return any(
quota_reason in failure['reason']
for quota_reason in ['RESOURCE:MEMORY', 'RESOURCE:CPU']
Expand All @@ -48,10 +49,10 @@ def should_retry(exception: Exception):


@runtime_checkable
class ECSProtocol(Protocol):
class EcsProtocol(Protocol):
"""
A structured Protocol for ``boto3.client('ecs')``. This is used for type hints on
:py:meth:`.ECSOperator.client`.
:py:meth:`.EcsOperator.client`.

.. seealso::

Expand Down Expand Up @@ -84,7 +85,7 @@ def list_tasks(self, cluster: str, launchType: str, desiredStatus: str, family:
...


class ECSTaskLogFetcher(Thread):
class EcsTaskLogFetcher(Thread):
"""
Fetches Cloudwatch log events with specific interval as a thread
and sends the log events to the info channel of the provided logger.
Expand Down Expand Up @@ -151,13 +152,13 @@ def stop(self):
self._event.set()


class ECSOperator(BaseOperator):
class EcsOperator(BaseOperator):
"""
Execute a task on AWS ECS (Elastic Container Service)

.. seealso::
For more information on how to use this operator, take a look at the guide:
:ref:`howto/operator:ECSOperator`
:ref:`howto/operator:EcsOperator`

:param task_definition: the task definition name on Elastic Container Service
:type task_definition: str
Expand Down Expand Up @@ -289,17 +290,17 @@ def __init__(
self.awslogs_region = region_name

self.hook: Optional[AwsBaseHook] = None
self.client: Optional[ECSProtocol] = None
self.client: Optional[EcsProtocol] = None
self.arn: Optional[str] = None
self.retry_args = quota_retry
self.task_log_fetcher: Optional[ECSTaskLogFetcher] = None
self.task_log_fetcher: Optional[EcsTaskLogFetcher] = None

@provide_session
def execute(self, context, session=None):
self.log.info(
'Running ECS Task - Task definition: %s - on cluster %s', self.task_definition, self.cluster
)
self.log.info('ECSOperator overrides: %s', self.overrides)
self.log.info('EcsOperator overrides: %s', self.overrides)

self.client = self.get_hook().get_conn()

Expand Down Expand Up @@ -371,7 +372,7 @@ def _start_task(self, context):

failures = response['failures']
if len(failures) > 0:
raise ECSOperatorError(failures, response)
raise EcsOperatorError(failures, response)
self.log.info('ECS Task started: %s', response)

self.arn = response['tasks'][0]['taskArn']
Expand Down Expand Up @@ -430,11 +431,12 @@ def _wait_for_task_ended(self) -> None:
def _aws_logs_enabled(self):
return self.awslogs_group and self.awslogs_stream_prefix

def _get_task_log_fetcher(self) -> ECSTaskLogFetcher:
def _get_task_log_fetcher(self) -> EcsTaskLogFetcher:
if not self.awslogs_group:
raise ValueError("must specify awslogs_group to fetch task logs")
log_stream_name = f"{self.awslogs_stream_prefix}/{self.ecs_task_id}"
return ECSTaskLogFetcher(

return EcsTaskLogFetcher(
aws_conn_id=self.aws_conn_id,
region_name=self.awslogs_region,
log_group=self.awslogs_group,
Expand Down Expand Up @@ -509,3 +511,50 @@ def on_kill(self) -> None:
cluster=self.cluster, task=self.arn, reason='Task killed by the user'
)
self.log.info(response)


class ECSOperator(EcsOperator):
"""
This operator is deprecated.
Please use :class:`airflow.providers.amazon.aws.operators.ecs.EcsOperator`.
"""

def __init__(self, *args, **kwargs):
warnings.warn(
"This operator is deprecated. "
"Please use `airflow.providers.amazon.aws.operators.ecs.EcsOperator`.",
DeprecationWarning,
stacklevel=2,
)
super().__init__(*args, **kwargs)


class ECSTaskLogFetcher(EcsTaskLogFetcher):
"""
This class is deprecated.
Please use :class:`airflow.providers.amazon.aws.operators.ecs.EcsTaskLogFetcher`.
"""

def __init__(self, *args, **kwargs):
warnings.warn(
"This class is deprecated. "
"Please use `airflow.providers.amazon.aws.operators.ecs.EcsTaskLogFetcher`.",
DeprecationWarning,
stacklevel=2,
)
super().__init__(*args, **kwargs)


class ECSProtocol(EcsProtocol):
"""
This class is deprecated.
Please use :class:`airflow.providers.amazon.aws.operators.ecs.EcsProtocol`.
"""

def __init__(self):
warnings.warn(
"This class is deprecated. "
"Please use `airflow.providers.amazon.aws.operators.ecs.EcsProtocol`.",
DeprecationWarning,
stacklevel=2,
)
6 changes: 3 additions & 3 deletions docs/apache-airflow-providers-amazon/operators/ecs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
under the License.


.. _howto/operator:ECSOperator:
.. _howto/operator:EcsOperator:

ECS Operator
============
Expand All @@ -30,14 +30,14 @@ Using Operator
--------------

Use the
:class:`~airflow.providers.amazon.aws.operators.ecs.ECSOperator`
:class:`~airflow.providers.amazon.aws.operators.ecs.EcsOperator`
to run a task defined in AWS ECS.

In the following example,
the task "hello_world" runs ``hello-world`` task in ``c`` cluster.
It overrides the command in the ``hello-world-container`` container.

Before using ECSOperator, *cluster* and *task definition* need to be created.
Before using EcsOperator, *cluster* and *task definition* need to be created.

.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_ecs_fargate.py
:language: python
Expand Down
42 changes: 21 additions & 21 deletions tests/providers/amazon/aws/operators/test_ecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@
from parameterized import parameterized

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.exceptions import ECSOperatorError
from airflow.providers.amazon.aws.operators.ecs import ECSOperator, ECSTaskLogFetcher, should_retry
from airflow.providers.amazon.aws.exceptions import EcsOperatorError
from airflow.providers.amazon.aws.operators.ecs import EcsOperator, EcsTaskLogFetcher, should_retry

# fmt: off
RESPONSE_WITHOUT_FAILURES = {
Expand All @@ -55,7 +55,7 @@
# fmt: on


class TestECSOperator(unittest.TestCase):
class TestEcsOperator(unittest.TestCase):
@mock.patch('airflow.providers.amazon.aws.operators.ecs.AwsBaseHook')
def set_up_operator(self, aws_hook_mock, **kwargs):
self.aws_hook_mock = aws_hook_mock
Expand All @@ -77,7 +77,7 @@ def set_up_operator(self, aws_hook_mock, **kwargs):
},
'propagate_tags': 'TASK_DEFINITION',
}
self.ecs = ECSOperator(**self.ecs_operator_args, **kwargs)
self.ecs = EcsOperator(**self.ecs_operator_args, **kwargs)
self.ecs.get_hook()

def setUp(self):
Expand Down Expand Up @@ -163,8 +163,8 @@ def test_template_fields_overrides(self):
],
]
)
@mock.patch.object(ECSOperator, '_wait_for_task_ended')
@mock.patch.object(ECSOperator, '_check_success_task')
@mock.patch.object(EcsOperator, '_wait_for_task_ended')
@mock.patch.object(EcsOperator, '_check_success_task')
def test_execute_without_failures(
self,
launch_type,
Expand Down Expand Up @@ -214,7 +214,7 @@ def test_execute_with_failures(self):
resp_failures['failures'].append('dummy error')
client_mock.run_task.return_value = resp_failures

with pytest.raises(ECSOperatorError):
with pytest.raises(EcsOperatorError):
self.ecs.execute(None)

self.aws_hook_mock.return_value.get_conn.assert_called_once()
Expand Down Expand Up @@ -409,15 +409,15 @@ def test_check_success_task_not_raises(self):
['', {'testTagKey': 'testTagValue'}],
]
)
@mock.patch.object(ECSOperator, "_xcom_del")
@mock.patch.object(EcsOperator, "_xcom_del")
@mock.patch.object(
ECSOperator,
EcsOperator,
"xcom_pull",
return_value="arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b55",
)
@mock.patch.object(ECSOperator, '_wait_for_task_ended')
@mock.patch.object(ECSOperator, '_check_success_task')
@mock.patch.object(ECSOperator, '_start_task')
@mock.patch.object(EcsOperator, '_wait_for_task_ended')
@mock.patch.object(EcsOperator, '_check_success_task')
@mock.patch.object(EcsOperator, '_start_task')
def test_reattach_successful(
self, launch_type, tags, start_mock, check_mock, wait_mock, xcom_pull_mock, xcom_del_mock
):
Expand Down Expand Up @@ -467,11 +467,11 @@ def test_reattach_successful(
['', {'testTagKey': 'testTagValue'}],
]
)
@mock.patch.object(ECSOperator, '_xcom_del')
@mock.patch.object(ECSOperator, '_xcom_set')
@mock.patch.object(ECSOperator, '_try_reattach_task')
@mock.patch.object(ECSOperator, '_wait_for_task_ended')
@mock.patch.object(ECSOperator, '_check_success_task')
@mock.patch.object(EcsOperator, '_xcom_del')
@mock.patch.object(EcsOperator, '_xcom_set')
@mock.patch.object(EcsOperator, '_try_reattach_task')
@mock.patch.object(EcsOperator, '_wait_for_task_ended')
@mock.patch.object(EcsOperator, '_check_success_task')
def test_reattach_save_task_arn_xcom(
self, launch_type, tags, check_mock, wait_mock, reattach_mock, xcom_set_mock, xcom_del_mock
):
Expand Down Expand Up @@ -532,18 +532,18 @@ def test_execute_xcom_disabled(self):

class TestShouldRetry(unittest.TestCase):
def test_return_true_on_valid_reason(self):
self.assertTrue(should_retry(ECSOperatorError([{'reason': 'RESOURCE:MEMORY'}], 'Foo')))
self.assertTrue(should_retry(EcsOperatorError([{'reason': 'RESOURCE:MEMORY'}], 'Foo')))

def test_return_false_on_invalid_reason(self):
self.assertFalse(should_retry(ECSOperatorError([{'reason': 'CLUSTER_NOT_FOUND'}], 'Foo')))
self.assertFalse(should_retry(EcsOperatorError([{'reason': 'CLUSTER_NOT_FOUND'}], 'Foo')))


class TestECSTaskLogFetcher(unittest.TestCase):
class TestEcsTaskLogFetcher(unittest.TestCase):
@mock.patch('logging.Logger')
def set_up_log_fetcher(self, logger_mock):
self.logger_mock = logger_mock

self.log_fetcher = ECSTaskLogFetcher(
self.log_fetcher = EcsTaskLogFetcher(
log_group="test_log_group",
log_stream_name="test_log_stream_name",
fetch_interval=timedelta(milliseconds=1),
Expand Down
2 changes: 1 addition & 1 deletion tests/providers/amazon/aws/operators/test_ecs_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@


@pytest.mark.backend("postgres", "mysql")
class ECSSystemTest(AmazonSystemTest):
class EcsSystemTest(AmazonSystemTest):
"""
ECS System Test to run and test example ECS dags

Expand Down