Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
import sys
import traceback
import warnings
from collections.abc import Mapping, MutableMapping, Sequence
from collections.abc import Collection, Mapping, MutableMapping, Sequence
from concurrent.futures import ProcessPoolExecutor
from typing import TYPE_CHECKING, Any

Expand Down Expand Up @@ -320,14 +320,14 @@ class BulkStateFetcher(LoggingMixin):
Otherwise, multiprocessing.Pool will be used. Each task status will be downloaded individually.
"""

def __init__(self, sync_parallelism=None):
def __init__(self, sync_parallelism: int):
super().__init__()
self._sync_parallelism = sync_parallelism

def _tasks_list_to_task_ids(self, async_tasks) -> set[str]:
def _tasks_list_to_task_ids(self, async_tasks: Collection[AsyncResult]) -> set[str]:
return {a.task_id for a in async_tasks}

def get_many(self, async_results) -> Mapping[str, EventBufferValueType]:
def get_many(self, async_results: Collection[AsyncResult]) -> Mapping[str, EventBufferValueType]:
"""Get status for many Celery tasks using the best method available."""
if isinstance(app.backend, BaseKeyValueStoreBackend):
result = self._get_many_from_kv_backend(async_results)
Expand All @@ -338,7 +338,9 @@ def get_many(self, async_results) -> Mapping[str, EventBufferValueType]:
self.log.debug("Fetched %d state(s) for %d task(s)", len(result), len(async_results))
return result

def _get_many_from_kv_backend(self, async_tasks) -> Mapping[str, EventBufferValueType]:
def _get_many_from_kv_backend(
self, async_tasks: Collection[AsyncResult]
) -> Mapping[str, EventBufferValueType]:
task_ids = self._tasks_list_to_task_ids(async_tasks)
keys = [app.backend.get_key_for_task(k) for k in task_ids]
values = app.backend.mget(keys)
Expand All @@ -348,13 +350,15 @@ def _get_many_from_kv_backend(self, async_tasks) -> Mapping[str, EventBufferValu
return self._prepare_state_and_info_by_task_dict(task_ids, task_results_by_task_id)

@retry
def _query_task_cls_from_db_backend(self, task_ids, **kwargs):
def _query_task_cls_from_db_backend(self, task_ids: set[str], **kwargs):
session = app.backend.ResultSession()
task_cls = getattr(app.backend, "task_cls", TaskDb)
with session_cleanup(session):
return session.scalars(select(task_cls).where(task_cls.task_id.in_(task_ids))).all()

def _get_many_from_db_backend(self, async_tasks) -> Mapping[str, EventBufferValueType]:
def _get_many_from_db_backend(
self, async_tasks: Collection[AsyncResult]
) -> Mapping[str, EventBufferValueType]:
task_ids = self._tasks_list_to_task_ids(async_tasks)
tasks = self._query_task_cls_from_db_backend(task_ids)
task_results = [app.backend.meta_from_decoded(task.to_dict()) for task in tasks]
Expand All @@ -364,21 +368,23 @@ def _get_many_from_db_backend(self, async_tasks) -> Mapping[str, EventBufferValu

@staticmethod
def _prepare_state_and_info_by_task_dict(
task_ids, task_results_by_task_id
task_ids: set[str], task_results_by_task_id: dict[str, dict[str, Any]]
) -> Mapping[str, EventBufferValueType]:
state_info: MutableMapping[str, EventBufferValueType] = {}
for task_id in task_ids:
task_result = task_results_by_task_id.get(task_id)
if task_result:
state = task_result["status"]
info = None if not hasattr(task_result, "info") else task_result["info"]
info = task_result.get("info")
else:
state = celery_states.PENDING
info = None
state_info[task_id] = state, info
return state_info

def _get_many_using_multiprocessing(self, async_results) -> Mapping[str, EventBufferValueType]:
def _get_many_using_multiprocessing(
self, async_results: Collection[AsyncResult]
) -> Mapping[str, EventBufferValueType]:
num_process = min(len(async_results), self._sync_parallelism)

with ProcessPoolExecutor(max_workers=num_process) as sync_pool:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ def test_should_support_kv_backend(self, mock_mget, caplog):
"airflow.providers.celery.executors.celery_executor_utils.Celery.backend", mock_backend
):
caplog.clear()
fetcher = celery_executor_utils.BulkStateFetcher()
fetcher = celery_executor_utils.BulkStateFetcher(1)
result = fetcher.get_many(
[
mock.MagicMock(task_id="123"),
Expand Down Expand Up @@ -367,7 +367,7 @@ def test_should_support_db_backend(self, mock_session, caplog):
mock.MagicMock(**{"to_dict.return_value": {"status": "SUCCESS", "task_id": "123"}})
]

fetcher = celery_executor_utils.BulkStateFetcher()
fetcher = celery_executor_utils.BulkStateFetcher(1)
result = fetcher.get_many(
[
mock.MagicMock(task_id="123"),
Expand Down Expand Up @@ -401,7 +401,7 @@ def test_should_retry_db_backend(self, mock_session, caplog):
mock_retry_db_result.return_value,
]

fetcher = celery_executor_utils.BulkStateFetcher()
fetcher = celery_executor_utils.BulkStateFetcher(1)
result = fetcher.get_many(
[
mock.MagicMock(task_id="123"),
Expand Down
Loading