Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from airflow.cli.simple_table import AirflowConsole
from airflow.configuration import conf
from airflow.exceptions import AirflowConfigException
from airflow.providers.celery.version_compat import AIRFLOW_V_3_0_PLUS
from airflow.providers.celery.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_2_PLUS
from airflow.utils import cli as cli_utils
from airflow.utils.cli import setup_locations

Expand Down Expand Up @@ -189,8 +189,32 @@ def filter(self, record):
@_providers_configuration_loaded
def worker(args):
"""Start Airflow Celery worker."""
# This needs to be imported locally to not trigger Providers Manager initialization
from airflow.providers.celery.executors.celery_executor import app as celery_app
team_config = None
if hasattr(args, "team") and args.team:
# Multi-team is enabled, create team-specific Celery app and use team based config
# This requires Airflow 3.2+, and core.multi_team config to be true to be enabled.
if not AIRFLOW_V_3_2_PLUS:
raise SystemExit(
"Error: Multi-team Celery workers require Airflow version 3.2 or higher. "
"Please upgrade your Airflow installation or remove the --team argument."
)
if not conf.getboolean("core", "multi_team", fallback=False):
raise SystemExit(
"Error: Multi-team Celery workers require core.multi_team configuration to be enabled. "
"Please enable core.multi_team in your Airflow config or remove the --team argument."
)
from airflow.executors.base_executor import ExecutorConf
from airflow.providers.celery.executors.celery_executor_utils import create_celery_app

team_config = ExecutorConf(team_name=args.team)
log.info("Starting Celery worker for team: %s", args.team)
celery_app = create_celery_app(team_config)
else:
# Backward compatible: use module-level app with global config
from airflow.providers.celery.executors.celery_executor import app as celery_app

# Use team_config for config reads in multi-team mode, otherwise use global conf
config = team_config if team_config else conf

# Check if a worker with the same hostname already exists
if args.celery_hostname:
Expand Down Expand Up @@ -218,8 +242,8 @@ def worker(args):
autoscale = args.autoscale
skip_serve_logs = args.skip_serve_logs

if autoscale is None and conf.has_option("celery", "worker_autoscale"):
autoscale = conf.get("celery", "worker_autoscale")
if autoscale is None and config.has_option("celery", "worker_autoscale"):
autoscale = config.get("celery", "worker_autoscale")

if hasattr(celery_app.backend, "ResultSession"):
# Pre-create the database tables now, otherwise SQLA via Celery has a
Expand All @@ -238,9 +262,9 @@ def worker(args):
pass

# backwards-compatible: https://github.com/apache/airflow/pull/21506#pullrequestreview-879893763
celery_log_level = conf.get("logging", "CELERY_LOGGING_LEVEL")
celery_log_level = config.get("logging", "CELERY_LOGGING_LEVEL")
if not celery_log_level:
celery_log_level = conf.get("logging", "LOGGING_LEVEL")
celery_log_level = config.get("logging", "LOGGING_LEVEL")

# Setup Celery worker
options = [
Expand All @@ -263,8 +287,8 @@ def worker(args):
if args.without_gossip:
options.append("--without-gossip")

if conf.has_option("celery", "pool"):
pool = conf.get("celery", "pool")
if config.has_option("celery", "pool"):
pool = config.get("celery", "pool")
options.extend(["--pool", pool])
# Celery pools of type eventlet and gevent use greenlets, which
# requires monkey patching the app:
Expand All @@ -288,7 +312,7 @@ def run_celery_worker():
if args.umask:
umask = args.umask
else:
umask = conf.get("celery", "worker_umask", fallback=settings.DAEMON_UMASK)
umask = config.get("celery", "worker_umask", fallback=settings.DAEMON_UMASK)

_run_command_with_daemon_option(
args=args,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,10 @@
help="Don't subscribe to other workers events",
action="store_true",
)
# Optional ``-t`` / ``--team`` flag: names the team whose team-specific
# configuration should be used. Per the help text, this requires Airflow 3.2+.
ARG_TEAM = Arg(
    ("-t", "--team"),
    help="Team name for team-specific multi-team configuration (requires Airflow 3.2+)",
)
ARG_OUTPUT = Arg(
(
"-o",
Expand Down Expand Up @@ -139,6 +143,7 @@
ARG_QUEUES,
ARG_CONCURRENCY,
ARG_CELERY_HOSTNAME,
ARG_TEAM,
ARG_PID,
ARG_DAEMON,
ARG_UMASK,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
from celery import states as celery_states
from deprecated import deprecated

from airflow.configuration import conf
from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.executors.base_executor import BaseExecutor
from airflow.providers.celery.version_compat import AIRFLOW_V_3_0_PLUS
Expand Down Expand Up @@ -105,18 +104,34 @@ class CeleryExecutor(BaseExecutor):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

# Check if self has the ExecutorConf set on the self.conf attribute with all required methods.
# In Airflow 2.x, ExecutorConf exists but lacks methods like getint, getboolean, getsection, etc.
# In such cases, fall back to the global configuration object.
# This allows the changes to be backwards compatible with older versions of Airflow.
# Can be removed when minimum supported provider version is equal to the version of core airflow
# which introduces multi-team configuration (3.2+).
if not hasattr(self, "conf") or not hasattr(self.conf, "getint"):
from airflow.configuration import conf as global_conf

self.conf = global_conf

# Create Celery app, it will be team specific if the configuration has been set for that.
from airflow.providers.celery.executors.celery_executor_utils import create_celery_app

self.celery_app = create_celery_app(self.conf)

# Celery doesn't support bulk sending the tasks (which can become a bottleneck on bigger clusters)
# so we use a multiprocessing pool to speed this up.
# How many worker processes are created for checking celery task state.
self._sync_parallelism = conf.getint("celery", "SYNC_PARALLELISM")
self._sync_parallelism = self.conf.getint("celery", "SYNC_PARALLELISM", fallback=0)
if self._sync_parallelism == 0:
self._sync_parallelism = max(1, cpu_count() - 1)
from airflow.providers.celery.executors.celery_executor_utils import BulkStateFetcher

self.bulk_state_fetcher = BulkStateFetcher(self._sync_parallelism)
self.bulk_state_fetcher = BulkStateFetcher(self._sync_parallelism, celery_app=self.celery_app)
self.tasks = {}
self.task_publish_retries: Counter[TaskInstanceKey] = Counter()
self.task_publish_max_retries = conf.getint("celery", "task_publish_max_retries")
self.task_publish_max_retries = self.conf.getint("celery", "task_publish_max_retries", fallback=3)

def start(self) -> None:
self.log.debug("Starting Celery Executor using %s processes for syncing", self._sync_parallelism)
Expand All @@ -131,19 +146,17 @@ def _num_tasks_per_send_process(self, to_send_count: int) -> int:

def _process_tasks(self, task_tuples: Sequence[TaskTuple]) -> None:
# Airflow V2 version
from airflow.providers.celery.executors.celery_executor_utils import execute_command

task_tuples_to_send = [task_tuple[:3] + (execute_command,) for task_tuple in task_tuples]
task_tuples_to_send = [task_tuple[:3] + (self.team_name,) for task_tuple in task_tuples]

self._send_tasks(task_tuples_to_send)

def _process_workloads(self, workloads: Sequence[workloads.All]) -> None:
# Airflow V3 version -- have to delay imports until we know we are on v3
from airflow.executors.workloads import ExecuteTask
from airflow.providers.celery.executors.celery_executor_utils import execute_workload

tasks = [
(workload.ti.key, workload, workload.ti.queue, execute_workload)
(workload.ti.key, workload, workload.ti.queue, self.team_name)
for workload in workloads
if isinstance(workload, ExecuteTask)
]
Expand All @@ -154,11 +167,9 @@ def _process_workloads(self, workloads: Sequence[workloads.All]) -> None:
self._send_tasks(tasks)

def _send_tasks(self, task_tuples_to_send: Sequence[TaskInstanceInCelery]):
first_task = next(t[-1] for t in task_tuples_to_send)

# Celery state queries will be stuck if we do not use one same backend
# for all tasks.
cached_celery_backend = first_task.backend
cached_celery_backend = self.celery_app.backend

key_and_async_results = self._send_tasks_to_celery(task_tuples_to_send)
self.log.debug("Sent all tasks.")
Expand Down Expand Up @@ -206,6 +217,8 @@ def _send_tasks_to_celery(self, task_tuples_to_send: Sequence[TaskInstanceInCele
chunksize = self._num_tasks_per_send_process(len(task_tuples_to_send))
num_processes = min(len(task_tuples_to_send), self._sync_parallelism)

# Use ProcessPoolExecutor with team_name instead of task objects to avoid pickling issues.
# Subprocesses reconstruct the team-specific Celery app from the team name and existing config.
with ProcessPoolExecutor(max_workers=num_processes) as send_pool:
key_and_async_results = list(
send_pool.map(send_task_to_executor, task_tuples_to_send, chunksize=chunksize)
Expand Down Expand Up @@ -343,12 +356,10 @@ def cleanup_stuck_queued_tasks(self, tis: list[TaskInstance]) -> list[str]:
return reprs

def revoke_task(self, *, ti: TaskInstance):
from airflow.providers.celery.executors.celery_executor_utils import app

celery_async_result = self.tasks.pop(ti.key, None)
if celery_async_result:
try:
app.control.revoke(celery_async_result.task_id)
self.celery_app.control.revoke(celery_async_result.task_id)
except Exception:
self.log.exception("Error revoking task instance %s from celery", ti.key)
self.running.discard(ti.key)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,13 @@
from functools import cache
from typing import TYPE_CHECKING, Any

from celery import Celery, Task, states as celery_states
from celery import Celery, states as celery_states
from celery.backends.base import BaseKeyValueStoreBackend
from celery.backends.database import DatabaseBackend, Task as TaskDb, retry, session_cleanup
from celery.signals import import_modules as celery_import_modules
from sqlalchemy import select

from airflow.configuration import conf
from airflow.configuration import AirflowConfigParser, conf
from airflow.executors.base_executor import BaseExecutor
from airflow.providers.celery.version_compat import AIRFLOW_V_3_0_PLUS
from airflow.providers.common.compat.sdk import AirflowException, AirflowTaskTimeout, Stats, timeout
Expand All @@ -66,14 +66,16 @@
from celery.result import AsyncResult

from airflow.executors import workloads
from airflow.executors.base_executor import EventBufferValueType
from airflow.executors.base_executor import EventBufferValueType, ExecutorConf
from airflow.models.taskinstance import TaskInstanceKey

# We can't use `if AIRFLOW_V_3_0_PLUS` conditions in type checks, so unfortunately we just have to define
# the type as the union of both kinds
CommandType = Sequence[str]

TaskInstanceInCelery: TypeAlias = tuple[TaskInstanceKey, workloads.All | CommandType, str | None, Task]
TaskInstanceInCelery: TypeAlias = tuple[
TaskInstanceKey, workloads.All | CommandType, str | None, str | None
]

TaskTuple = tuple[TaskInstanceKey, CommandType, str | None, Any | None]

Expand Down Expand Up @@ -102,6 +104,38 @@ def _get_celery_app() -> Celery:
return Celery(celery_app_name, config_source=get_celery_configuration())


def create_celery_app(team_conf: ExecutorConf | AirflowConfigParser) -> Celery:
    """
    Build a Celery application from the supplied configuration object.

    The app name is read from ``[celery] CELERY_APP_NAME``. When ``team_conf``
    exposes a truthy ``team_name`` attribute, the name is suffixed with
    ``_<team_name>`` so that each team gets a distinctly named app; a
    configuration object without that attribute keeps the plain name.

    :param team_conf: ExecutorConf instance with team-specific configuration, or global conf
    :return: Celery app instance with the executor tasks registered on it
    """
    # Imported lazily, mirroring the other function-scope provider imports in this module.
    from airflow.providers.celery.executors.default_celery import get_default_celery_config

    base_name = team_conf.get("celery", "CELERY_APP_NAME")

    # Only objects carrying a team_name get the per-team suffix.
    team = getattr(team_conf, "team_name", None)
    app_name = f"{base_name}_{team}" if team else base_name

    new_app = Celery(app_name, config_source=get_default_celery_config(team_conf))

    # Register the task callables on this specific app under fixed task names.
    # ``execute_command`` is only registered when not running on Airflow 3.0+.
    task_callables = {"execute_workload": execute_workload}
    if not AIRFLOW_V_3_0_PLUS:
        task_callables["execute_command"] = execute_command
    for task_name, task_callable in task_callables.items():
        new_app.task(name=task_name)(task_callable)

    return new_app


# Keep module-level app for backward compatibility
app = _get_celery_app()


Expand Down Expand Up @@ -272,14 +306,39 @@ def __init__(self, exception: BaseException, exception_traceback: str):
def send_task_to_executor(
task_tuple: TaskInstanceInCelery,
) -> tuple[TaskInstanceKey, CommandType, AsyncResult | ExceptionWithTraceback]:
"""Send task to executor."""
key, args, queue, task_to_run = task_tuple
"""
Send task to executor.

This function is called in ProcessPoolExecutor subprocesses. To avoid pickling issues with
team-specific Celery apps, we pass the team_name and reconstruct the Celery app here.
"""
key, args, queue, team_name = task_tuple

# Reconstruct the Celery app from configuration, which may or may not be team-specific.
# ExecutorConf wraps config access to automatically use team-specific config where present.
if TYPE_CHECKING:
_conf: ExecutorConf | AirflowConfigParser
# Check if Airflow version is greater than or equal to 3.2 to import ExecutorConf
if AIRFLOW_V_3_0_PLUS:
from airflow.executors.base_executor import ExecutorConf

_conf = ExecutorConf(team_name)
else:
# Airflow <3.2 ExecutorConf doesn't exist (at least not with the required attributes), fall back to global conf
_conf = conf

# Create the Celery app with the correct configuration
celery_app = create_celery_app(_conf)

if AIRFLOW_V_3_0_PLUS:
# Get the task from the app
task_to_run = celery_app.tasks["execute_workload"]
if TYPE_CHECKING:
assert isinstance(args, workloads.BaseWorkload)
args = (args.model_dump_json(),)
else:
# Get the task from the app
task_to_run = celery_app.tasks["execute_command"]
args = [args] # type: ignore[list-item]
try:
with timeout(seconds=OPERATION_TIMEOUT):
Expand Down Expand Up @@ -324,18 +383,19 @@ class BulkStateFetcher(LoggingMixin):
Otherwise, multiprocessing.Pool will be used. Each task status will be downloaded individually.
"""

def __init__(self, sync_parallelism: int):
def __init__(self, sync_parallelism: int, celery_app: Celery | None = None):
super().__init__()
self._sync_parallelism = sync_parallelism
self.celery_app = celery_app or app # Use provided app or fall back to module-level app

def _tasks_list_to_task_ids(self, async_tasks: Collection[AsyncResult]) -> set[str]:
return {a.task_id for a in async_tasks}

def get_many(self, async_results: Collection[AsyncResult]) -> Mapping[str, EventBufferValueType]:
"""Get status for many Celery tasks using the best method available."""
if isinstance(app.backend, BaseKeyValueStoreBackend):
if isinstance(self.celery_app.backend, BaseKeyValueStoreBackend):
result = self._get_many_from_kv_backend(async_results)
elif isinstance(app.backend, DatabaseBackend):
elif isinstance(self.celery_app.backend, DatabaseBackend):
result = self._get_many_from_db_backend(async_results)
else:
result = self._get_many_using_multiprocessing(async_results)
Expand All @@ -346,17 +406,17 @@ def _get_many_from_kv_backend(
self, async_tasks: Collection[AsyncResult]
) -> Mapping[str, EventBufferValueType]:
task_ids = self._tasks_list_to_task_ids(async_tasks)
keys = [app.backend.get_key_for_task(k) for k in task_ids]
values = app.backend.mget(keys)
task_results = [app.backend.decode_result(v) for v in values if v]
keys = [self.celery_app.backend.get_key_for_task(k) for k in task_ids]
values = self.celery_app.backend.mget(keys)
task_results = [self.celery_app.backend.decode_result(v) for v in values if v]
task_results_by_task_id = {task_result["task_id"]: task_result for task_result in task_results}

return self._prepare_state_and_info_by_task_dict(task_ids, task_results_by_task_id)

@retry
def _query_task_cls_from_db_backend(self, task_ids: set[str], **kwargs):
session = app.backend.ResultSession()
task_cls = getattr(app.backend, "task_cls", TaskDb)
session = self.celery_app.backend.ResultSession()
task_cls = getattr(self.celery_app.backend, "task_cls", TaskDb)
with session_cleanup(session):
return session.scalars(select(task_cls).where(task_cls.task_id.in_(task_ids))).all()

Expand All @@ -365,7 +425,7 @@ def _get_many_from_db_backend(
) -> Mapping[str, EventBufferValueType]:
task_ids = self._tasks_list_to_task_ids(async_tasks)
tasks = self._query_task_cls_from_db_backend(task_ids)
task_results = [app.backend.meta_from_decoded(task.to_dict()) for task in tasks]
task_results = [self.celery_app.backend.meta_from_decoded(task.to_dict()) for task in tasks]
task_results_by_task_id = {task_result["task_id"]: task_result for task_result in task_results}

return self._prepare_state_and_info_by_task_dict(task_ids, task_results_by_task_id)
Expand Down
Loading
Loading