Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 77 additions & 23 deletions airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from copy import deepcopy
from typing import TYPE_CHECKING

from botocore.exceptions import ClientError
from botocore.exceptions import ClientError, NoCredentialsError

from airflow.configuration import conf
from airflow.exceptions import AirflowException
Expand All @@ -42,6 +42,9 @@
EcsQueuedTask,
EcsTaskCollection,
)
from airflow.providers.amazon.aws.executors.utils.exponential_backoff_retry import exponential_backoff_retry
from airflow.providers.amazon.aws.hooks.ecs import EcsHook
from airflow.utils import timezone
from airflow.utils.state import State

if TYPE_CHECKING:
Expand All @@ -51,6 +54,12 @@
ExecutorConfigType,
)

INVALID_CREDENTIALS_EXCEPTIONS = [
"ExpiredTokenException",
"InvalidClientTokenId",
"UnrecognizedClientException",
]


class AwsEcsExecutor(BaseExecutor):
"""
Expand Down Expand Up @@ -91,30 +100,15 @@ def __init__(self, *args, **kwargs):

self.cluster = conf.get(CONFIG_GROUP_NAME, AllEcsConfigKeys.CLUSTER)
self.container_name = conf.get(CONFIG_GROUP_NAME, AllEcsConfigKeys.CONTAINER_NAME)
aws_conn_id = conf.get(
CONFIG_GROUP_NAME,
AllEcsConfigKeys.AWS_CONN_ID,
fallback=CONFIG_DEFAULTS[AllEcsConfigKeys.AWS_CONN_ID],
)
region_name = conf.get(CONFIG_GROUP_NAME, AllEcsConfigKeys.REGION_NAME)
from airflow.providers.amazon.aws.hooks.ecs import EcsHook
self.attempts_since_last_successful_connection = 0

self.load_ecs_connection(check_connection=False)
self.IS_BOTO_CONNECTION_HEALTHY = False

self.ecs = EcsHook(aws_conn_id=aws_conn_id, region_name=region_name).conn
self.run_task_kwargs = self._load_run_kwargs()

def start(self):
"""
Make a test API call to check the health of the ECS Executor.

Deliberately use an invalid task ID, some potential outcomes in order:
1. "AccessDeniedException" is raised if there are insufficient permissions.
2. "ClusterNotFoundException" is raised if permissions exist but the cluster does not.
3. The API responds with a failure message if the cluster is found and there
are permissions, but the cluster itself has issues.
4. "InvalidParameterException" is raised if the permissions and cluster exist but the task does not.

The last one is considered a success state for the purposes of this check.
"""
"""This is called by the scheduler when the Executor is being run for the first time."""
check_health = conf.getboolean(
CONFIG_GROUP_NAME, AllEcsConfigKeys.CHECK_HEALTH_ON_STARTUP, fallback=False
)
Expand All @@ -123,7 +117,25 @@ def start(self):
return

self.log.info("Starting ECS Executor and determining health...")
try:
self.check_health()
except AirflowException:
self.log.error("Stopping the Airflow Scheduler from starting until the issue is resolved.")
raise

def check_health(self):
"""
Make a test API call to check the health of the ECS Executor.

Deliberately use an invalid task ID, some potential outcomes in order:
1. `AccessDeniedException` is raised if there are insufficient permissions.
2. `ClusterNotFoundException` is raised if permissions exist but the cluster does not.
3. The API responds with a failure message if the cluster is found and there
are permissions, but the cluster itself has issues.
4. `InvalidParameterException` is raised if the permissions and cluster exist but the task does not.

The last one is considered a success state for the purposes of this check.
"""
success_status = "succeeded."
status = success_status

Expand Down Expand Up @@ -151,18 +163,50 @@ def start(self):
finally:
msg_prefix = "ECS Executor health check has %s"
if status == success_status:
self.IS_BOTO_CONNECTION_HEALTHY = True
self.log.info(msg_prefix, status)
else:
msg_error_suffix = (
"The ECS executor will not be able to run Airflow tasks until the issue is addressed. "
"Stopping the Airflow Scheduler from starting until the issue is resolved."
"The ECS executor will not be able to run Airflow tasks until the issue is addressed."
)
raise AirflowException(msg_prefix % status + msg_error_suffix)

def load_ecs_connection(self, check_connection: bool = True):
    """
    (Re)build the boto ECS client from the configured connection settings.

    Reads the AWS connection id (with a config default fallback) and region from
    the executor configuration, recreates ``self.ecs`` via ``EcsHook``, and
    records the reload attempt counters consumed by the exponential-backoff
    retry logic in ``sync``.

    :param check_connection: When True, run ``check_health`` after reconnecting
        and, if it succeeds, reset the failed-attempt counter to zero.
    """
    self.log.info("Loading Connection information")
    aws_conn_id = conf.get(
        CONFIG_GROUP_NAME,
        AllEcsConfigKeys.AWS_CONN_ID,
        fallback=CONFIG_DEFAULTS[AllEcsConfigKeys.AWS_CONN_ID],
    )
    region_name = conf.get(CONFIG_GROUP_NAME, AllEcsConfigKeys.REGION_NAME)
    self.ecs = EcsHook(aws_conn_id=aws_conn_id, region_name=region_name).conn
    # Track every reload attempt and its timestamp so that sync() can space out
    # reconnection attempts with exponential backoff.
    self.attempts_since_last_successful_connection += 1
    self.last_connection_reload = timezone.utcnow()

    if check_connection:
        # check_health raises AirflowException on failure, so the counter is
        # only reset when the new connection is actually healthy.
        self.check_health()
        self.attempts_since_last_successful_connection = 0

def sync(self):
if not self.IS_BOTO_CONNECTION_HEALTHY:
exponential_backoff_retry(
self.last_connection_reload,
self.attempts_since_last_successful_connection,
self.load_ecs_connection,
)
if not self.IS_BOTO_CONNECTION_HEALTHY:
return
try:
self.sync_running_tasks()
self.attempt_task_runs()
except (ClientError, NoCredentialsError) as error:
error_code = error.response["Error"]["Code"]
if error_code in INVALID_CREDENTIALS_EXCEPTIONS:
self.IS_BOTO_CONNECTION_HEALTHY = False
self.log.warning(
f"AWS credentials are either missing or expired: {error}.\nRetrying connection"
)

except Exception:
# We catch any and all exceptions because otherwise they would bubble
# up and kill the scheduler process
Expand All @@ -176,6 +220,7 @@ def sync_running_tasks(self):
return

describe_tasks_response = self.__describe_tasks(all_task_arns)

self.log.debug("Active Workers: %s", describe_tasks_response)

if describe_tasks_response["failures"]:
Expand Down Expand Up @@ -288,6 +333,15 @@ def attempt_task_runs(self):
_failure_reasons = []
try:
run_task_response = self._run_task(task_key, cmd, queue, exec_config)
except NoCredentialsError:
self.pending_tasks.appendleft(ecs_task)
raise
except ClientError as e:
error_code = e.response["Error"]["Code"]
if error_code in INVALID_CREDENTIALS_EXCEPTIONS:
self.pending_tasks.appendleft(ecs_task)
raise
_failure_reasons.append(str(e))
except Exception as e:
# Failed to even get a response back from the Boto3 API or something else went
# wrong. For any possible failure we want to add the exception reasons to the
Expand Down
16 changes: 16 additions & 0 deletions airflow/providers/amazon/aws/executors/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import logging
from datetime import datetime, timedelta
from typing import Callable

from airflow.utils import timezone

log = logging.getLogger(__name__)


def exponential_backoff_retry(
    last_attempt_time: datetime,
    attempts_since_last_successful: int,
    callable_function: Callable,
    max_delay: int = 60 * 2,
    max_attempts: int = -1,
    exponent_base: int = 4,
) -> None:
    """
    Invoke *callable_function* with exponential backoff between attempts.

    The callable is only executed once enough time has elapsed since the
    previous attempt; failures are logged rather than propagated.

    :param last_attempt_time: Timestamp of last attempt call.
    :param attempts_since_last_successful: Number of attempts since last success.
    :param callable_function: Callable function that will be called if enough time has passed.
    :param max_delay: Maximum delay in seconds between retries. Default 120.
    :param max_attempts: Maximum number of attempts before giving up. Default -1 (no limit).
    :param exponent_base: Exponent base to calculate delay. Default 4.
    """
    # Stop retrying entirely once the attempt budget is spent (-1 means unlimited).
    attempts_are_limited = max_attempts != -1
    if attempts_are_limited and attempts_since_last_successful >= max_attempts:
        log.error("Max attempts reached. Exiting.")
        return

    # The backoff window grows as exponent_base ** attempts, capped at max_delay.
    backoff_seconds = min(exponent_base**attempts_since_last_successful, max_delay)
    retry_not_due_until = last_attempt_time + timedelta(seconds=backoff_seconds)
    if timezone.utcnow() < retry_not_due_until:
        # Still inside the backoff window; skip this cycle without calling.
        return

    try:
        callable_function()
    except Exception:
        # Swallow the failure (it is the caller's loop that drives retries) but
        # log it along with the next backoff delay.
        log.exception("Error calling %r", callable_function.__name__)
        upcoming_delay = min(exponent_base ** (attempts_since_last_successful + 1), max_delay)
        log.info("Waiting for %s seconds before retrying.", upcoming_delay)
Original file line number Diff line number Diff line change
Expand Up @@ -550,7 +550,7 @@ As a final step, access to the database must be configured for the ECS container

3. Select the security group associated with your RDS instance, and click Edit inbound rules.

4. Add a new rule that allows PostgreSQL type traffic to the CIDR of the subnet(s) associated with the DB.
4. Add a new rule that allows PostgreSQL type traffic to the CIDR of the subnet(s) associated with the ECS cluster.

Configure Airflow
~~~~~~~~~~~~~~~~~
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def set_env_vars():
def mock_executor(set_env_vars) -> AwsEcsExecutor:
"""Mock ECS to a repeatable starting state.."""
executor = AwsEcsExecutor()
executor.IS_BOTO_CONNECTION_HEALTHY = True

# Replace boto3 ECS client with mock.
ecs_mock = mock.Mock(spec=executor.ecs)
Expand Down
16 changes: 16 additions & 0 deletions tests/providers/amazon/aws/executors/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
Loading