Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 29 additions & 9 deletions airflow-core/src/airflow/executors/local_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def _run_worker(
input: SimpleQueue[workloads.All | None],
output: Queue[TaskInstanceStateType],
unread_messages: multiprocessing.sharedctypes.Synchronized[int],
team_conf,
):
import signal

Expand All @@ -67,8 +68,11 @@ def _run_worker(
log = structlog.get_logger(logger_name)
log.info("Worker starting up pid=%d", os.getpid())

# Create team suffix for process title
team_suffix = f" [{team_conf.team_name}]" if team_conf.team_name else ""

while True:
setproctitle("airflow worker -- LocalExecutor: <idle>", log)
setproctitle(f"airflow worker -- LocalExecutor{team_suffix}: <idle>", log)
try:
workload = input.get()
except EOFError:
Expand Down Expand Up @@ -96,27 +100,29 @@ def _run_worker(
raise TypeError(f"Don't know how to get ti key from {type(workload).__name__}")

try:
_execute_work(log, workload)
_execute_work(log, workload, team_conf)

output.put((key, TaskInstanceState.SUCCESS, None))
except Exception as e:
log.exception("uhoh")
output.put((key, TaskInstanceState.FAILED, e))


def _execute_work(log: Logger, workload: workloads.ExecuteTask) -> None:
def _execute_work(log: Logger, workload: workloads.ExecuteTask, team_conf) -> None:
"""
Execute command received and stores result state in queue.

:param key: the key to identify the task instance
:param command: the command to execute
:param log: Logger instance
:param workload: The workload to execute
:param team_conf: Team-specific executor configuration
"""
from airflow.configuration import conf
from airflow.sdk.execution_time.supervisor import supervise

setproctitle(f"airflow worker -- LocalExecutor: {workload.ti.id}", log)
# Create team suffix for process title
team_suffix = f" [{team_conf.team_name}]" if team_conf.team_name else ""
setproctitle(f"airflow worker -- LocalExecutor{team_suffix}: {workload.ti.id}", log)

base_url = conf.get("api", "base_url", fallback="/")
base_url = team_conf.get("api", "base_url", fallback="/")
# If it's a relative URL, use localhost:8080 as the default
if base_url.startswith("/"):
base_url = f"http://localhost:8080{base_url}"
Expand All @@ -130,7 +136,7 @@ def _execute_work(log: Logger, workload: workloads.ExecuteTask) -> None:
dag_rel_path=workload.dag_rel_path,
bundle_info=workload.bundle_info,
token=workload.token,
server=conf.get("core", "execution_api_server_url", fallback=default_execution_api_server),
server=team_conf.get("core", "execution_api_server_url", fallback=default_execution_api_server),
log_path=workload.log_path,
)

Expand All @@ -154,6 +160,19 @@ class LocalExecutor(BaseExecutor):
workers: dict[int, multiprocessing.Process]
_unread_messages: multiprocessing.sharedctypes.Synchronized[int]

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

# Check if self has the ExecutorConf set on the self.conf attribute, and if not, set it to the global
# configuration object. This allows the changes to be backwards compatible with older versions of
# Airflow.
# Can be removed when minimum supported provider version is equal to the version of core airflow
# which introduces multi-team configuration.
if not hasattr(self, "conf"):
from airflow.configuration import conf

self.conf = conf

def start(self) -> None:
"""Start the executor."""
# We delay opening these queues until the start method mostly for unit tests. ExecutorLoader caches
Expand Down Expand Up @@ -212,6 +231,7 @@ def _spawn_worker(self):
"input": self.activity_queue,
"output": self.result_queue,
"unread_messages": self._unread_messages,
"team_conf": self.conf,
},
)
p.start()
Expand Down
91 changes: 90 additions & 1 deletion airflow-core/tests/unit/executors/test_local_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,8 +221,11 @@ def test_clean_stop_on_signal(self):
@mock.patch("airflow.sdk.execution_time.supervisor.supervise")
def test_execution_api_server_url_config(self, mock_supervise, conf_values, expected_server):
"""Test that execution_api_server_url is correctly configured with fallback"""
from airflow.executors.base_executor import ExecutorConf

with conf_vars(conf_values):
_execute_work(log=mock.ANY, workload=mock.MagicMock())
team_conf = ExecutorConf(team_name=None)
_execute_work(log=mock.ANY, workload=mock.MagicMock(), team_conf=team_conf)

mock_supervise.assert_called_with(
ti=mock.ANY,
Expand All @@ -232,3 +235,89 @@ def test_execution_api_server_url_config(self, mock_supervise, conf_values, expe
server=expected_server,
log_path=mock.ANY,
)

@mock.patch("airflow.sdk.execution_time.supervisor.supervise")
def test_team_and_global_config_isolation(self, mock_supervise):
"""Test that team-specific and global executors use correct configurations side-by-side"""
from airflow.executors.base_executor import ExecutorConf

team_name = "ml_team"
team_server = "http://team-ml-server:8080/execution/"
default_server = "http://default-server/execution/"

# Set up global configuration
config_overrides = {
("api", "base_url"): "http://default-server",
("core", "execution_api_server_url"): default_server,
}

# Use environment variables for team-specific config
import os

team_env_key = f"AIRFLOW__{team_name.upper()}___CORE__EXECUTION_API_SERVER_URL"

with mock.patch.dict(os.environ, {team_env_key: team_server}):
with conf_vars(config_overrides):
# Test team-specific config
team_conf = ExecutorConf(team_name=team_name)
_execute_work(log=mock.ANY, workload=mock.MagicMock(), team_conf=team_conf)

# Verify team-specific server URL was used
assert mock_supervise.call_count == 1
call_kwargs = mock_supervise.call_args[1]
assert call_kwargs["server"] == team_server

mock_supervise.reset_mock()

# Test global config (no team)
global_conf = ExecutorConf(team_name=None)
_execute_work(log=mock.ANY, workload=mock.MagicMock(), team_conf=global_conf)

# Verify default server URL was used
assert mock_supervise.call_count == 1
call_kwargs = mock_supervise.call_args[1]
assert call_kwargs["server"] == default_server

def test_multiple_team_executors_isolation(self):
"""Test that multiple team executors can coexist with isolated resources"""
team_a_executor = LocalExecutor(parallelism=2, team_name="team_a")
team_b_executor = LocalExecutor(parallelism=3, team_name="team_b")

team_a_executor.start()
team_b_executor.start()

try:
# Verify each executor has its own queues
assert team_a_executor.activity_queue is not team_b_executor.activity_queue
assert team_a_executor.result_queue is not team_b_executor.result_queue

# Verify each executor has its own workers dict
assert team_a_executor.workers is not team_b_executor.workers
assert len(team_a_executor.workers) == 2
assert len(team_b_executor.workers) == 3

# Verify each executor has its own unread_messages counter
assert team_a_executor._unread_messages is not team_b_executor._unread_messages

# Verify each has correct team config
assert team_a_executor.conf.team_name == "team_a"
assert team_b_executor.conf.team_name == "team_b"

finally:
team_a_executor.end()
team_b_executor.end()

def test_global_executor_without_team_name(self):
"""Test that global executor (no team) works correctly"""
executor = LocalExecutor(parallelism=2)

# Verify executor has conf but no team name
assert hasattr(executor, "conf")
assert executor.conf.team_name is None

executor.start()

# Verify workers were created
assert len(executor.workers) == 2

executor.end()
Original file line number Diff line number Diff line change
Expand Up @@ -896,6 +896,7 @@ def _get_option_from_defaults(
section: str,
issue_warning: bool = True,
extra_stacklevel: int = 0,
team_name: str | None = None,
**kwargs,
) -> str | ValueNotFound:
"""Get config option from default values."""
Expand Down