21 changes: 13 additions & 8 deletions airflow-core/src/airflow/configuration.py
@@ -877,12 +877,14 @@ def mask_secrets(self):
mask_secret_core(value)
mask_secret_sdk(value)

def _env_var_name(self, section: str, key: str) -> str:
return f"{ENV_VAR_PREFIX}{section.replace('.', '_').upper()}__{key.upper()}"

def _get_env_var_option(self, section: str, key: str):
# must have format AIRFLOW__{SECTION}__{KEY} (note double underscore)
env_var = self._env_var_name(section, key)
def _env_var_name(self, section: str, key: str, team_name: str | None = None) -> str:
team_component: str = f"{team_name.upper()}___" if team_name else ""
return f"{ENV_VAR_PREFIX}{team_component}{section.replace('.', '_').upper()}__{key.upper()}"

def _get_env_var_option(self, section: str, key: str, team_name: str | None = None):
# must have format AIRFLOW__{SECTION}__{KEY} (note double underscore), or for team-based
# configuration AIRFLOW__{TEAM_NAME}___{SECTION}__{KEY} (note the triple underscore after the team name)
env_var: str = self._env_var_name(section, key, team_name=team_name)
if env_var in os.environ:
return expand_env_var(os.environ[env_var])
# alternatively AIRFLOW__{SECTION}__{KEY}_CMD (for a command)
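
For concreteness, the naming scheme above can be sketched in isolation. The AIRFLOW__ prefix value matches the variables used in the test further down (e.g. AIRFLOW__CELERY__RESULT_BACKEND); note the triple underscore after the team component versus the double underscore around section and key:

    # Minimal sketch of the naming scheme implemented by _env_var_name above.
    ENV_VAR_PREFIX = "AIRFLOW__"

    def env_var_name(section: str, key: str, team_name: str | None = None) -> str:
        # The team component uses a TRIPLE underscore separator; section and key use double.
        team_component = f"{team_name.upper()}___" if team_name else ""
        return f"{ENV_VAR_PREFIX}{team_component}{section.replace('.', '_').upper()}__{key.upper()}"

    assert env_var_name("celery", "result_backend") == "AIRFLOW__CELERY__RESULT_BACKEND"
    assert env_var_name("celery", "result_backend", "team_a") == "AIRFLOW__TEAM_A___CELERY__RESULT_BACKEND"
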
@@ -982,6 +984,7 @@ def get( # type: ignore[misc]
suppress_warnings: bool = False,
lookup_from_deprecated: bool = True,
_extra_stacklevel: int = 0,
team_name: str | None = None,
**kwargs,
) -> str | None:
section = section.lower()
@@ -1044,6 +1047,7 @@ def get( # type: ignore[misc]
section,
issue_warning=not warning_emitted,
extra_stacklevel=_extra_stacklevel,
team_name=team_name,
)
if option is not None:
return option
@@ -1170,13 +1174,14 @@ def _get_environment_variables(
section: str,
issue_warning: bool = True,
extra_stacklevel: int = 0,
team_name: str | None = None,
) -> str | None:
option = self._get_env_var_option(section, key)
option = self._get_env_var_option(section, key, team_name=team_name)
if option is not None:
return option
if deprecated_section and deprecated_key:
with self.suppress_future_warnings():
option = self._get_env_var_option(deprecated_section, deprecated_key)
option = self._get_env_var_option(deprecated_section, deprecated_key, team_name=team_name)
if option is not None:
if issue_warning:
self._warn_deprecate(section, key, deprecated_section, deprecated_key, extra_stacklevel)
23 changes: 23 additions & 0 deletions airflow-core/src/airflow/executors/base_executor.py
@@ -97,6 +97,28 @@ def can_try_again(self):
return True


class ExecutorConf:
"""
Fetch configuration for an executor scoped to a particular team_name.

It wraps configuration.get() so that lookups for a section and key are prefixed with the
team_name. This lets child classes (i.e. concrete executors) fetch configuration values for a
particular team without having to pass team_name through on every call.

Currently, only environment variables are supported for team-specific configuration.
"""

def __init__(self, team_name: str | None = None) -> None:
self.team_name: str | None = team_name

def get(self, *args, **kwargs) -> str | None:
return conf.get(*args, **kwargs, team_name=self.team_name)

def getboolean(self, *args, **kwargs) -> bool:
return conf.getboolean(*args, **kwargs, team_name=self.team_name)


class BaseExecutor(LoggingMixin):
"""
Base class to inherit for concrete executors such as Celery, Kubernetes, Local, etc.
@@ -150,6 +172,7 @@ def __init__(self, parallelism: int = PARALLELISM, team_name: str | None = None)
self.running: set[TaskInstanceKey] = set()
self.event_buffer: dict[TaskInstanceKey, EventBufferValueType] = {}
self._task_event_logs: deque[Log] = deque()
self.conf = ExecutorConf(team_name)

if self.parallelism <= 0:
raise ValueError("parallelism is set to 0 or lower")
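As an illustration of the wrapper in use, a minimal sketch follows; the team name, section, and key here are hypothetical, not options defined by Airflow:

    # Sketch: with team_name set, the team-prefixed environment variable is resolved.
    import os

    from airflow.executors.base_executor import ExecutorConf

    os.environ["AIRFLOW__TEAM_A___CORE__SOME_KEY"] = "team-value"  # hypothetical option

    team_conf = ExecutorConf(team_name="team_a")
    assert team_conf.get("core", "some_key") == "team-value"  # forwards team_name to conf.get()
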
11 changes: 11 additions & 0 deletions airflow-core/tests/unit/core/test_configuration.py
@@ -158,6 +158,17 @@ def test_env_var_config(self):

assert conf.has_option("testsection", "testkey")

def test_env_team(self):
with patch(
"os.environ",
{
"AIRFLOW__CELERY__RESULT_BACKEND": "FOO",
"AIRFLOW__UNIT_TEST_TEAM___CELERY__RESULT_BACKEND": "BAR",
},
):
assert conf.get("celery", "result_backend") == "FOO"
assert conf.get("celery", "result_backend", team_name="unit_test_team") == "BAR"

@conf_vars({("core", "percent"): "with%%inside"})
def test_conf_as_dict(self):
cfg_dict = conf.as_dict()
airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
@@ -32,7 +32,6 @@

from botocore.exceptions import ClientError, NoCredentialsError

from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.executors.base_executor import BaseExecutor
from airflow.providers.amazon.aws.executors.ecs.boto_schema import BotoDescribeTasksSchema, BotoRunTaskSchema
@@ -98,13 +97,6 @@ class AwsEcsExecutor(BaseExecutor):
Airflow TaskInstance's executor_config.
"""

# Maximum number of retries to run an ECS task.
MAX_RUN_TASK_ATTEMPTS = conf.get(
CONFIG_GROUP_NAME,
AllEcsConfigKeys.MAX_RUN_TASK_ATTEMPTS,
fallback=CONFIG_DEFAULTS[AllEcsConfigKeys.MAX_RUN_TASK_ATTEMPTS],
)

# AWS limits the maximum number of ARNs in the describe_tasks function.
DESCRIBE_TASKS_BATCH_SIZE = 99

@@ -118,15 +110,32 @@ def __init__(self, *args, **kwargs):
self.active_workers: EcsTaskCollection = EcsTaskCollection()
self.pending_tasks: deque = deque()

self.cluster = conf.get(CONFIG_GROUP_NAME, AllEcsConfigKeys.CLUSTER)
self.container_name = conf.get(CONFIG_GROUP_NAME, AllEcsConfigKeys.CONTAINER_NAME)
# Check whether an ExecutorConf has already been set on the self.conf attribute and, if not,
# fall back to the global configuration object. This keeps the change backward compatible
# with older versions of Airflow.
# Can be removed once the minimum supported version of core Airflow is the one that
# introduces multi-team configuration.
if not hasattr(self, "conf"):
from airflow.configuration import conf

self.conf = conf

self.cluster = self.conf.get(CONFIG_GROUP_NAME, AllEcsConfigKeys.CLUSTER)
self.container_name = self.conf.get(CONFIG_GROUP_NAME, AllEcsConfigKeys.CONTAINER_NAME)
self.attempts_since_last_successful_connection = 0

self.load_ecs_connection(check_connection=False)
self.IS_BOTO_CONNECTION_HEALTHY = False

self.run_task_kwargs = self._load_run_kwargs()

# Maximum number of retries to run an ECS task.
self.max_run_task_attempts = self.conf.get(
CONFIG_GROUP_NAME,
AllEcsConfigKeys.MAX_RUN_TASK_ATTEMPTS,
fallback=CONFIG_DEFAULTS[AllEcsConfigKeys.MAX_RUN_TASK_ATTEMPTS],
)

def queue_workload(self, workload: workloads.All, session: Session | None) -> None:
from airflow.executors import workloads

@@ -154,7 +163,7 @@ def _process_workloads(self, workloads: Sequence[workloads.All]) -> None:

def start(self):
"""Call this when the Executor is run for the first time by the scheduler."""
check_health = conf.getboolean(
check_health = self.conf.getboolean(
CONFIG_GROUP_NAME, AllEcsConfigKeys.CHECK_HEALTH_ON_STARTUP, fallback=False
)

@@ -218,12 +227,12 @@ def check_health(self):

def load_ecs_connection(self, check_connection: bool = True):
self.log.info("Loading Connection information")
aws_conn_id = conf.get(
aws_conn_id = self.conf.get(
CONFIG_GROUP_NAME,
AllEcsConfigKeys.AWS_CONN_ID,
fallback=CONFIG_DEFAULTS[AllEcsConfigKeys.AWS_CONN_ID],
)
region_name = conf.get(CONFIG_GROUP_NAME, AllEcsConfigKeys.REGION_NAME, fallback=None)
region_name = self.conf.get(CONFIG_GROUP_NAME, AllEcsConfigKeys.REGION_NAME, fallback=None)
self.ecs = EcsHook(aws_conn_id=aws_conn_id, region_name=region_name).conn
self.attempts_since_last_successful_connection += 1
self.last_connection_reload = timezone.utcnow()
@@ -340,13 +349,13 @@ def __handle_failed_task(self, task_arn: str, reason: str):
queue = task_info.queue
exec_info = task_info.config
failure_count = self.active_workers.failure_count_by_key(task_key)
if int(failure_count) < int(self.__class__.MAX_RUN_TASK_ATTEMPTS):
if int(failure_count) < int(self.max_run_task_attempts):
self.log.warning(
"Airflow task %s failed due to %s. Failure %s out of %s occurred on %s. Rescheduling.",
task_key,
reason,
failure_count,
self.__class__.MAX_RUN_TASK_ATTEMPTS,
self.max_run_task_attempts,
task_arn,
)
self.pending_tasks.append(
@@ -416,8 +425,8 @@ def attempt_task_runs(self):
failure_reasons.extend([f["reason"] for f in run_task_response["failures"]])

if failure_reasons:
# Make sure the number of attempts does not exceed MAX_RUN_TASK_ATTEMPTS
if int(attempt_number) < int(self.__class__.MAX_RUN_TASK_ATTEMPTS):
# Make sure the number of attempts does not exceed max_run_task_attempts
if int(attempt_number) < int(self.max_run_task_attempts):
ecs_task.attempt_number += 1
ecs_task.next_attempt_time = timezone.utcnow() + calculate_next_attempt_delay(
attempt_number
@@ -545,7 +554,7 @@ def terminate(self):
def _load_run_kwargs(self) -> dict:
from airflow.providers.amazon.aws.executors.ecs.ecs_executor_config import build_task_kwargs

ecs_executor_run_task_kwargs = build_task_kwargs()
ecs_executor_run_task_kwargs = build_task_kwargs(self.conf)

try:
self.get_container(ecs_executor_run_task_kwargs["overrides"]["containerOverrides"])["command"]
airflow/providers/amazon/aws/executors/ecs/ecs_executor_config.py
@@ -32,7 +32,6 @@
import json
from json import JSONDecodeError

from airflow.configuration import conf
from airflow.providers.amazon.aws.executors.ecs.utils import (
CONFIG_GROUP_NAME,
ECS_LAUNCH_TYPE_EC2,
@@ -46,23 +45,27 @@
from airflow.utils.helpers import prune_dict


def _fetch_templated_kwargs() -> dict[str, str]:
run_task_kwargs_value = conf.get(CONFIG_GROUP_NAME, AllEcsConfigKeys.RUN_TASK_KWARGS, fallback=dict())
def _fetch_templated_kwargs(conf) -> dict[str, str]:
run_task_kwargs_value = conf.get(
CONFIG_GROUP_NAME,
AllEcsConfigKeys.RUN_TASK_KWARGS,
fallback=dict(),
)
return json.loads(str(run_task_kwargs_value))


def _fetch_config_values() -> dict[str, str]:
def _fetch_config_values(conf) -> dict[str, str]:
return prune_dict(
{key: conf.get(CONFIG_GROUP_NAME, key, fallback=None) for key in RunTaskKwargsConfigKeys()}
)


def build_task_kwargs() -> dict:
def build_task_kwargs(conf) -> dict:
all_config_keys = AllEcsConfigKeys()
# This will put some kwargs at the root of the dictionary that do NOT belong there. However,
# the code below expects them to be there and will rearrange them as necessary.
task_kwargs = _fetch_config_values()
task_kwargs.update(_fetch_templated_kwargs())
task_kwargs = _fetch_config_values(conf)
task_kwargs.update(_fetch_templated_kwargs(conf))

has_launch_type: bool = all_config_keys.LAUNCH_TYPE in task_kwargs
has_capacity_provider: bool = all_config_keys.CAPACITY_PROVIDER_STRATEGY in task_kwargs
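A sketch of the new call shape: the executor now passes its conf object into the builder explicitly (as _load_run_kwargs does above) instead of the module importing the global conf. The team name is hypothetical, and the call assumes the aws_ecs_executor section has been configured (the rest of build_task_kwargs, not shown in this diff, expects values such as the container name):

    # Sketch: building ECS run_task kwargs from a team-scoped configuration wrapper.
    from airflow.executors.base_executor import ExecutorConf
    from airflow.providers.amazon.aws.executors.ecs.ecs_executor_config import build_task_kwargs

    task_kwargs = build_task_kwargs(ExecutorConf(team_name="team_a"))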