Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion airflow/jobs/scheduler_job_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,7 +709,8 @@ def _process_executor_events(self, session: Session) -> int:

if state == TaskInstanceState.QUEUED:
ti.external_executor_id = info
self.log.info("Setting external_id for %s to %s", ti, info)
# TODO don't actually commit this log change
self.log.warning("Setting external_id for %s to %s", ti, info)
continue

msg = (
Expand Down
69 changes: 66 additions & 3 deletions airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import time
from collections import defaultdict, deque
from copy import deepcopy
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Sequence

from botocore.exceptions import ClientError, NoCredentialsError

Expand All @@ -47,12 +47,13 @@
exponential_backoff_retry,
)
from airflow.providers.amazon.aws.hooks.ecs import EcsHook
from airflow.stats import Stats
from airflow.utils import timezone
from airflow.utils.helpers import merge_dicts
from airflow.utils.state import State

if TYPE_CHECKING:
from airflow.models.taskinstance import TaskInstanceKey
from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
from airflow.providers.amazon.aws.executors.ecs.utils import (
CommandType,
ExecutorConfigType,
Expand Down Expand Up @@ -110,6 +111,11 @@ def __init__(self, *args, **kwargs):
self.IS_BOTO_CONNECTION_HEALTHY = False

self.run_task_kwargs = self._load_run_kwargs()
self.adopt_task_instances = conf.getboolean(
CONFIG_GROUP_NAME,
AllEcsConfigKeys.ADOPT_TASK_INSTANCES,
fallback=CONFIG_DEFAULTS[AllEcsConfigKeys.ADOPT_TASK_INSTANCES],
)

def start(self):
"""Call this when the Executor is run for the first time by the scheduler."""
Expand Down Expand Up @@ -393,6 +399,9 @@ def attempt_task_runs(self):
else:
task = run_task_response["tasks"][0]
self.active_workers.add_task(task, task_key, queue, cmd, exec_config, attempt_number)
# Add Fargate task ARN to executor event buffer, which gets saved
# in TaskInstance.external_executor_id.
self.event_buffer[task_key] = (State.QUEUED, task.task_arn)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have self.success and self.fail (e.g. link) which handles the other state changes. These are provided by the base executor. I wonder if we should add a new method for putting a task in queued state 🤔 but this might cause weird issues if a provider which contains an executor is installed alongside an older version of airflow... (also, not required for this PR, just thinking out loud)

if failure_reasons:
self.log.error(
"Pending ECS tasks failed to launch for the following reasons: %s. Retrying later.",
Expand Down Expand Up @@ -444,6 +453,10 @@ def execute_async(self, key: TaskInstanceKey, command: CommandType, queue=None,

def end(self, heartbeat_interval=10):
"""Wait for all currently running tasks to end, and don't launch any tasks."""
if self.adopt_task_instances:
self.log.info("Task adoption is enabled, not terminating tasks.")
return

try:
while True:
self.sync()
Expand All @@ -456,7 +469,11 @@ def end(self, heartbeat_interval=10):
self.log.exception("Failed to end %s", self.__class__.__name__)

def terminate(self):
"""Kill all ECS processes by calling Boto3's StopTask API."""
"""Kill all ECS processes by calling Boto3's StopTask API if adopt_task_instances option is False."""
if self.adopt_task_instances:
self.log.info("Task adoption is enabled, not terminating tasks.")
return

try:
for arn in self.active_workers.get_all_arns():
self.ecs.stop_task(
Expand Down Expand Up @@ -493,3 +510,49 @@ def get_container(self, container_list):
'container "name" must be provided in "containerOverrides" configuration'
)
raise KeyError(f"No such container found by container name: {self.container_name}")

def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[TaskInstance]:
"""
Try to adopt running task instances if adopt_task_instances option is set to True.

These tasks instances should have an ECS process which can be adopted by the unique task ARN.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it expected that there is no difference for the adoption process between a completed/terminated state to a running state of the instance?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is right. For comparison,. k8s executor filters out "status == successful", but local just passes them all, and celery appears to print the state for each but isn't filtering.

Anything that is not adopted will be cleared by the scheduler and becomes eligible for re-scheduling.
"""
if not self.adopt_task_instances:
# Do not try to adopt task instances, return all orphaned tasks for clearing.
return tis

self.log.warning("Task adoption is on. ECS Executor attempting to adopt tasks.")

with Stats.timer("ecs_executor.adopt_task_instances.duration"):
adopted_tis: list[TaskInstance] = []

from pprint import pformat
self.log.warning("tis: \n%s", pformat([vars(ti) for ti in tis]))

if task_arns := [ti.external_executor_id for ti in tis if ti.external_executor_id]:
task_descriptions = self.__describe_tasks(task_arns).get("tasks", [])

for task in task_descriptions:
ti = [ti for ti in tis if ti.external_executor_id == task.task_arn][0]
self.active_workers.add_task(
task,
ti.key,
ti.queue,
ti.command_as_list(),
ti.executor_config,
ti.prev_attempted_tries,
)
adopted_tis.append(ti)

if adopted_tis:
tasks = [f"{task} in state {task.state}" for task in adopted_tis]
task_instance_str = "\n\t".join(tasks)
self.log.info(
"Adopted the following %d tasks from a dead executor:\n\t%s",
len(adopted_tis),
task_instance_str,
)

not_adopted_tis = [ti for ti in tis if ti not in adopted_tis]
return not_adopted_tis
10 changes: 6 additions & 4 deletions airflow/providers/amazon/aws/executors/ecs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
"assign_public_ip": "False",
"platform_version": "LATEST",
"check_health_on_startup": "True",
"adopt_task_instances": "False",
}


Expand Down Expand Up @@ -84,22 +85,23 @@ class RunTaskKwargsConfigKeys(BaseConfigKeys):
ASSIGN_PUBLIC_IP = "assign_public_ip"
CAPACITY_PROVIDER_STRATEGY = "capacity_provider_strategy"
CLUSTER = "cluster"
CONTAINER_NAME = "container_name"
LAUNCH_TYPE = "launch_type"
PLATFORM_VERSION = "platform_version"
SECURITY_GROUPS = "security_groups"
SUBNETS = "subnets"
TASK_DEFINITION = "task_definition"
CONTAINER_NAME = "container_name"


class AllEcsConfigKeys(RunTaskKwargsConfigKeys):
    """All keys loaded into the config which are related to the ECS Executor."""

    # Keys are kept in alphabetical order; each must appear exactly once
    # (duplicated assignments were removed).
    ADOPT_TASK_INSTANCES = "adopt_task_instances"
    AWS_CONN_ID = "conn_id"
    CHECK_HEALTH_ON_STARTUP = "check_health_on_startup"
    MAX_RUN_TASK_ATTEMPTS = "max_run_task_attempts"
    REGION_NAME = "region_name"
    RUN_TASK_KWARGS = "run_task_kwargs"


class EcsExecutorException(Exception):
Expand Down
11 changes: 11 additions & 0 deletions airflow/providers/amazon/provider.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -959,6 +959,17 @@ config:
type: boolean
example: "True"
default: "True"
adopt_task_instances:
description: |
If True, the executor will try to adopt orphaned task instances from a SchedulerJob
shutdown event, for example when a scheduler container is re-deployed or terminated.
If False, the executor will terminate all active AWS ECS Tasks when the scheduler shuts down.
More documentation can be found in the Airflow docs:
``https://airflow.apache.org/docs/apache-airflow/stable/scheduler.html#scheduler-tuneables``.
version_added: "8.17"
type: boolean
example: "True"
default: "False"
aws_auth_manager:
description: |
This section only applies if you are using the AwsAuthManager. In other words, if you set
Expand Down
69 changes: 69 additions & 0 deletions tests/providers/amazon/aws/executors/ecs/test_ecs_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
from airflow.executors.base_executor import BaseExecutor
from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.providers.amazon.aws.executors.ecs import ecs_executor, ecs_executor_config
from airflow.models import TaskInstance
from airflow.providers.amazon.aws.executors.ecs import ecs_executor_config
from airflow.providers.amazon.aws.executors.ecs.boto_schema import BotoTaskSchema
from airflow.providers.amazon.aws.executors.ecs.ecs_executor import (
CONFIG_GROUP_NAME,
Expand Down Expand Up @@ -829,6 +831,72 @@ def test_update_running_tasks_failed(self, mock_executor, caplog):
"test failure" in caplog.messages[0]
)

def test_terminate_with_task_adoption(self, mock_executor):
"""Test that executor does not shut down active ECS tasks when adopt_task_instances is True."""
mock_executor.adopt_task_instances = True
mock_executor.terminate()

# tasks are not terminated
mock_executor.ecs.stop_task.assert_not_called()

def test_try_adopt_task_instances(self, mock_executor):
"""Test that executor can adopt orphaned task instances from a SchedulerJob shutdown event."""
mock_executor.adopt_task_instances = True

mock_executor.ecs.describe_tasks.return_value = {
"tasks": [
{
"taskArn": "001",
"lastStatus": "RUNNING",
"desiredStatus": "RUNNING",
"containers": [{"name": "some-ecs-container"}],
},
{
"taskArn": "002",
"lastStatus": "RUNNING",
"desiredStatus": "RUNNING",
"containers": [{"name": "another-ecs-container"}],
},
],
"failures": [],
}

orphaned_tasks = [
mock.Mock(spec=TaskInstance),
mock.Mock(spec=TaskInstance),
mock.Mock(spec=TaskInstance),
]
orphaned_tasks[0].external_executor_id = "001" # Matches a running task_arn
orphaned_tasks[1].external_executor_id = "002" # Matches a running task_arn
orphaned_tasks[2].external_executor_id = None # One orphaned task has no external_executor_id
for task in orphaned_tasks:
task.prev_attempted_tries = 1

not_adopted_tasks = mock_executor.try_adopt_task_instances(orphaned_tasks)

mock_executor.ecs.describe_tasks.assert_called_once()
# Two of the three tasks should be adopted.
assert len(orphaned_tasks) - 1 == len(mock_executor.active_workers)
# The remaining one task is unable to be adopted.
assert 1 == len(not_adopted_tasks)

def test_try_adopt_task_instances_disabled(self, mock_executor):
"""Test that executor won't adopt orphaned task instances if adopt_task_instances is False."""
orphaned_tasks = [
mock.Mock(TaskInstance),
mock.Mock(TaskInstance),
mock.Mock(TaskInstance),
]

not_adopted_tasks = mock_executor.try_adopt_task_instances(orphaned_tasks)

# Ensure that describe_tasks is not called.
mock_executor.ecs.describe_tasks.assert_not_called()
# No orphaned tasks are stored in active workers.
assert len(mock_executor.active_workers) == 0
# All tasks are unable to be adopted.
assert len(orphaned_tasks) == len(not_adopted_tasks)


class TestEcsExecutorConfig:
@pytest.fixture()
Expand Down Expand Up @@ -899,6 +967,7 @@ def test_config_defaults_are_applied(self, assign_subnets):
AllEcsConfigKeys.AWS_CONN_ID,
AllEcsConfigKeys.MAX_RUN_TASK_ATTEMPTS,
AllEcsConfigKeys.CHECK_HEALTH_ON_STARTUP,
AllEcsConfigKeys.ADOPT_TASK_INSTANCES,
]:
assert expected_key not in found_keys.keys()
else:
Expand Down