Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 47 additions & 3 deletions airflow/providers/amazon/aws/executors/batch/batch_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,11 @@

from __future__ import annotations

import contextlib
import time
from collections import defaultdict, deque
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Dict, List
from typing import TYPE_CHECKING, Any, Dict, List, Sequence

from botocore.exceptions import ClientError, NoCredentialsError

Expand All @@ -34,11 +35,12 @@
exponential_backoff_retry,
)
from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
from airflow.stats import Stats
from airflow.utils import timezone
from airflow.utils.helpers import merge_dicts

if TYPE_CHECKING:
from airflow.models.taskinstance import TaskInstanceKey
from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
from airflow.providers.amazon.aws.executors.batch.boto_schema import (
BatchDescribeJobsResponseSchema,
BatchSubmitJobResponseSchema,
Expand Down Expand Up @@ -306,14 +308,20 @@ def attempt_submit_jobs(self):
self.pending_jobs.append(batch_job)
else:
# Success case
job_id = submit_job_response["job_id"]
self.active_workers.add_job(
job_id=submit_job_response["job_id"],
job_id=job_id,
airflow_task_key=key,
airflow_cmd=cmd,
queue=queue,
exec_config=exec_config,
attempt_number=attempt_number,
)
with contextlib.suppress(AttributeError):
# TODO: Remove this when min_airflow_version is 2.10.0 or higher in Amazon provider.
# running_state is added in Airflow 2.10 and only needed to support task adoption
# (an optional executor feature).
self.running_state(key, job_id)
if failure_reasons:
self.log.error(
"Pending Batch jobs failed to launch for the following reasons: %s. Retrying later.",
Expand Down Expand Up @@ -418,3 +426,39 @@ def _load_submit_kwargs() -> dict:
" and value should be NULL or empty."
)
return submit_kwargs

def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[TaskInstance]:
    """
    Adopt task instances which have an external_executor_id (the Batch job ID).

    Every task instance with an ``external_executor_id`` is looked up through
    ``self._describe_jobs``; each job that comes back is registered in
    ``self.active_workers`` under the key of its task instance.

    Anything that is not adopted will be cleared by the scheduler and becomes eligible for re-scheduling.

    :param tis: Orphaned task instances offered to this executor for adoption.
    :return: The task instances from ``tis`` that were not adopted.
    """
    with Stats.timer("batch_executor.adopt_task_instances.duration"):
        adopted_tis: list[TaskInstance] = []

        if job_ids := [ti.external_executor_id for ti in tis if ti.external_executor_id]:
            # Index the candidates once instead of rescanning ``tis`` for every job.
            # ``setdefault`` keeps the first task instance seen for a given job ID,
            # which is the same one a linear first-match search would select.
            tis_by_job_id: dict[str, TaskInstance] = {}
            for ti in tis:
                if ti.external_executor_id:
                    tis_by_job_id.setdefault(ti.external_executor_id, ti)

            batch_jobs = self._describe_jobs(job_ids)

            for batch_job in batch_jobs:
                ti = tis_by_job_id.get(batch_job.job_id)
                if ti is None:
                    # A described job without a matching task instance must not abort
                    # the adoption of the remaining jobs.
                    self.log.warning(
                        "Batch job %s does not match any orphaned task instance; skipping adoption.",
                        batch_job.job_id,
                    )
                    continue
                self.active_workers.add_job(
                    job_id=batch_job.job_id,
                    airflow_task_key=ti.key,
                    airflow_cmd=ti.command_as_list(),
                    queue=ti.queue,
                    exec_config=ti.executor_config,
                    attempt_number=ti.prev_attempted_tries,
                )
                adopted_tis.append(ti)

        if adopted_tis:
            tasks = [f"{task} in state {task.state}" for task in adopted_tis]
            task_instance_str = "\n\t".join(tasks)
            self.log.info(
                "Adopted the following %d tasks from a dead executor:\n\t%s",
                len(adopted_tis),
                task_instance_str,
            )

        not_adopted_tis = [ti for ti in tis if ti not in adopted_tis]
        return not_adopted_tis
29 changes: 29 additions & 0 deletions tests/providers/amazon/aws/executors/batch/test_batch_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

from airflow.exceptions import AirflowException
from airflow.executors.base_executor import BaseExecutor
from airflow.models import TaskInstance
from airflow.providers.amazon.aws.executors.batch import batch_executor, batch_executor_config
from airflow.providers.amazon.aws.executors.batch.batch_executor import (
AwsBatchExecutor,
Expand Down Expand Up @@ -615,6 +616,34 @@ def _mock_sync(
}
executor.batch.describe_jobs.return_value = {"jobs": [after_batch_job]}

def test_try_adopt_task_instances(self, mock_executor):
    """Test that executor can adopt orphaned task instances from a SchedulerJob shutdown event."""
    # Batch reports two jobs; both correspond to orphaned task instances below.
    described_jobs = [{"jobId": job_id, "status": "SUCCEEDED"} for job_id in ("001", "002")]
    mock_executor.batch.describe_jobs.return_value = {"jobs": described_jobs}

    # The first two orphans carry a Batch job ID that describe_jobs returns;
    # the third has no external_executor_id at all.
    orphaned_tasks = []
    for external_id in ("001", "002", None):
        orphan = mock.Mock(spec=TaskInstance)
        orphan.external_executor_id = external_id
        orphan.try_number = 1
        orphaned_tasks.append(orphan)

    not_adopted_tasks = mock_executor.try_adopt_task_instances(orphaned_tasks)

    mock_executor.batch.describe_jobs.assert_called_once()
    # Two of the three tasks should be adopted.
    adopted_count = len(mock_executor.active_workers)
    assert adopted_count == len(orphaned_tasks) - 1
    # The remaining one task is unable to be adopted.
    assert len(not_adopted_tasks) == 1


class TestBatchExecutorConfig:
@staticmethod
Expand Down